From 89bb4c5532665e1abb7aef42f5e9711b10facc9e Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 18 Jan 2023 09:58:56 -0600 Subject: [PATCH 01/11] Implement adam optimizer cuda kernel --- src/optim/adam/adam.cu | 53 +++++++++++++++++++++++++++++ src/optim/adam/cpu_kernel.rs | 4 ++- src/optim/adam/cuda_kernel.rs | 64 +++++++++++++++++++++++++++++++++-- src/optim/adam/mod.rs | 11 +++--- src/optim/optimizer.rs | 38 +++++++++++++++++++++ 5 files changed, 161 insertions(+), 9 deletions(-) create mode 100644 src/optim/adam/adam.cu diff --git a/src/optim/adam/adam.cu b/src/optim/adam/adam.cu new file mode 100644 index 000000000..cddc5baf0 --- /dev/null +++ b/src/optim/adam/adam.cu @@ -0,0 +1,53 @@ +enum WeightDecayType { + None, + L2, + Decoupled +}; + +struct AdamConfig { + float lr; + float beta1; + float beta2; + float eps; + WeightDecayType weight_decay_type; + float weight_decay; +}; + +extern "C" __global__ void adam_update( + const AdamConfig cfg, + const size_t numel, + const float t, + float* param, + float* moment1, + float* moment2, + const float* grad +) { + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i >= numel) { + return; + } + + float p = param[i]; + float g = grad[i]; + float m = moment1[i]; + float v = moment2[i]; + + if (cfg.weight_decay_type == L2) { + g += cfg.weight_decay * p; + } + + m = m * cfg.beta1 + g * (1.0 - cfg.beta1); + v = v * cfg.beta2 + g * g * (1.0 - cfg.beta2); + float m_hat = m * 1.0 / (1.0 - powf(cfg.beta1, t)); + float v_hat = v * 1.0 / (1.0 - powf(cfg.beta2, t)); + g = cfg.lr * m_hat / (sqrtf(v_hat) + cfg.eps); + + if (cfg.weight_decay_type == Decoupled) { + g += cfg.weight_decay * cfg.lr * p; + } + + moment1[i] = m; + moment2[i] = v; + param[i] -= g; +} diff --git a/src/optim/adam/cpu_kernel.rs b/src/optim/adam/cpu_kernel.rs index 5e8f28531..25140c2db 100644 --- a/src/optim/adam/cpu_kernel.rs +++ b/src/optim/adam/cpu_kernel.rs @@ -3,13 +3,14 @@ use crate::{optim::WeightDecay, shapes::Shape, tensor::Cpu}; impl AdamKernel for Cpu { fn update( + &self, t: i32, cfg: &AdamConfig, param: &mut Self::Storage, moment1: &mut Self::Storage, moment2: &mut Self::Storage, grad: Self::Storage, - ) { + ) -> Result<(), Self::Err> { debug_assert_eq!(param.data.len(), grad.data.len()); debug_assert_eq!(param.shape, grad.shape); debug_assert_eq!(param.strides, grad.strides); @@ -35,5 +36,6 @@ impl AdamKernel for Cpu { *p -= g; } + Ok(()) } } diff --git a/src/optim/adam/cuda_kernel.rs b/src/optim/adam/cuda_kernel.rs index e77a585cf..9384a9626 100644 --- a/src/optim/adam/cuda_kernel.rs +++ b/src/optim/adam/cuda_kernel.rs @@ -1,14 +1,72 @@ +use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig}; +use std::sync::Arc; use crate::{shapes::Shape, tensor::Cuda}; +use super::AdamConfig; +use crate::optim::optimizer::*; + +#[repr(C)] +struct CudaAdamConfig { + lr: E, + beta1: E, + beta2: E, + eps: E, + weight_decay_type: WeightDecayType, + weight_decay: E, +} + +unsafe impl AsKernelParam for CudaAdamConfig {} + +fn adam_config_to_cuda(config: &AdamConfig) -> CudaAdamConfig { + let (weight_decay_type, weight_decay) = weight_decay_to_cuda(config.weight_decay); + + CudaAdamConfig { + lr: config.lr, + beta1: config.betas[0], + beta2: config.betas[1], + eps: config.eps, + weight_decay_type, + weight_decay + } +} + +const MODULE_NAME: &str = "adam"; +const FN_NAME: &str = "adam_update"; +const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/adam.ptx")); impl super::AdamKernel for Cuda { fn update( + &self, t: i32, - cfg: &super::AdamConfig, + cfg: 
&AdamConfig, param: &mut Self::Storage, moment1: &mut Self::Storage, moment2: &mut Self::Storage, grad: Self::Storage, - ) { - todo!() + ) -> Result<(), Self::Err> { + debug_assert_eq!(param.data.len(), grad.data.len()); + debug_assert_eq!(param.shape, grad.shape); + debug_assert_eq!(param.strides, grad.strides); + + if !self.dev.has_func(MODULE_NAME, FN_NAME) { + self.dev + .load_ptx(PTX_SRC.into(), MODULE_NAME, &[FN_NAME])?; + } + + let adam_cfg = adam_config_to_cuda(cfg); + let numel = param.shape.num_elements(); + + let func = self.dev.get_func(MODULE_NAME, FN_NAME).unwrap(); + let cfg = LaunchConfig::for_num_elems(numel as u32); + let params = ( + adam_cfg, // const AdamConfig cfg, + numel, // const size_t numel, + t as f32, // const float t, + Arc::make_mut(&mut param.data), // float* param, + Arc::make_mut(&mut moment1.data), // float* moment1, + Arc::make_mut(&mut moment2.data), // float* moment2, + grad.data.as_ref(), // const float* grad + ); + unsafe { func.launch_async(cfg, params) }?; + Ok(()) } } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index a60a6f943..5701ca715 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -115,13 +115,14 @@ impl Adam { pub(super) trait AdamKernel: DeviceStorage { fn update( + &self, t: i32, cfg: &AdamConfig, param: &mut Self::Storage, moment1: &mut Self::Storage, moment2: &mut Self::Storage, grad: Self::Storage, - ); + ) -> Result<(), Self::Err>; } impl, E: Dtype> ParamUpdater for Adam { @@ -136,7 +137,7 @@ impl, E: Dtype> ParamUpdater for Adam< Some(g) => { let m_t = self.moment1.get_or_alloc_mut(p)?; let v_t = self.moment2.get_or_alloc_mut(p)?; - D::update(self.t, &self.cfg, &mut p.storage, m_t, v_t, g); + p.device.update(self.t, &self.cfg, &mut p.storage, m_t, v_t, g)?; } } Ok(()) @@ -221,7 +222,7 @@ mod tests { for e in expected.iter() { let gradients = (t.trace() * rate.clone()).square().mean().backward(); opt.update(&mut t, gradients).expect(""); - assert_eq!(&t.array(), e); + assert_close(&t.array(), e); } } @@ -251,7 +252,7 @@ mod tests { for e in expected.iter() { let gradients = t.trace().exp().square().mean().backward(); opt.update(&mut t, gradients).expect(""); - assert_eq!(&t.array(), e); + assert_close(&t.array(), e); } } @@ -281,7 +282,7 @@ mod tests { for e in expected.iter() { let gradients = t.trace().exp().square().mean().backward(); opt.update(&mut t, gradients).expect(""); - assert_eq!(&t.array(), e); + assert_close(&t.array(), e); } } diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index c2f5c7056..c24143290 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -16,6 +16,25 @@ pub enum WeightDecay { Decoupled(E), } +/// Used to communicate the "WeightDecay" enum to cuda kernels +#[cfg(feature = "cuda")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub(super) enum WeightDecayType { + None, + L2, + Decoupled, +} + +#[cfg(feature = "cuda")] +pub(super) fn weight_decay_to_cuda(wd: Option>) -> (WeightDecayType, E) { + match wd { + None => (WeightDecayType::None, Default::default()), + Some(WeightDecay::L2(x)) => (WeightDecayType::L2, x), + Some(WeightDecay::Decoupled(x)) => (WeightDecayType::Decoupled, x), + } +} + /// Momentum used for [super::Sgd] and others #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Momentum { @@ -26,6 +45,25 @@ pub enum Momentum { Nesterov(E), } +/// Used to communicate the "Momentum" enum to cuda kernels +#[cfg(feature = "cuda")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub(super) enum MomentumType { + None, + 
Classic, + Nesterov, +} + +#[cfg(feature = "cuda")] +pub(super) fn momentum_to_cuda(wd: Option>) -> (MomentumType, E) { + match wd { + None => (MomentumType::None, Default::default()), + Some(Momentum::Classic(x)) => (MomentumType::Classic, x), + Some(Momentum::Nesterov(x)) => (MomentumType::Nesterov, x), + } +} + /// All optimizers must implement the update function, which takes an object /// that implements [GradientUpdate], and calls [GradientUpdate::update]. /// From baf7fef4195161c6c053d8d023a46d92b1feae88 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Thu, 19 Jan 2023 22:03:37 -0600 Subject: [PATCH 02/11] add chunk_sum to sum cuda kernel and modify supporting code to use it --- src/tensor_ops/reductions.txt | 31 ++++++++++++ src/tensor_ops/sum_to/cuda_kernel.rs | 28 +++++++---- src/tensor_ops/sum_to/sum_to.cu | 71 +++++++++++++++++++++++----- 3 files changed, 111 insertions(+), 19 deletions(-) create mode 100644 src/tensor_ops/reductions.txt diff --git a/src/tensor_ops/reductions.txt b/src/tensor_ops/reductions.txt new file mode 100644 index 000000000..88ee2e5bf --- /dev/null +++ b/src/tensor_ops/reductions.txt @@ -0,0 +1,31 @@ +have threads get elements in contiguous order and store into shared memory buffer + +segmentSum(i, block_i, segment_size, buf) + segment_i = i % segment_len + segment_start = i - segment_i + segment_end = segment_start + segment_size + + if block_i < segment_i + segment_idx = block_i + segment_len = segment_size - (segment_i - block_i) + else if segment_end > BLOCK_SIZE + segment_idx = segment_i + segment_len = BLOCK_SIZE - segment_start + else + segment_idx = segment_i + segment_len = segment_size + + stop = min(BLOCK_SIZE, segment_size) + segment = buf + block_i - segment_idx + + for (unsigned int s=1; s < stop; s *= 2) { + int index = 2 * s * tid; + + if (index < segmnet_len) { + segment[index] += segment[index + s]; + } + __syncthreads(); + } + + if segment_idx == 0 + out[i / segment_size] += segment[0] diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index e10a2989a..9a8afe9c8 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -6,6 +6,7 @@ use crate::{ use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; use std::sync::Arc; +use std::vec::Vec; const MODULE_NAME: &str = "sum_to"; const FWD_FN_NAME: &str = "sum_to_forward"; @@ -13,6 +14,16 @@ const BWD_FN_NAME: &str = "sum_to_backward"; const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/sum_to.ptx")); +pub fn remove_broadcasted_dims(dims: I, strides: I) -> (Vec, Vec) + where I: IntoIterator +{ + dims + .into_iter() + .zip(strides.into_iter()) + .filter(|(_, stride)| *stride != 0) + .unzip() +} + impl super::SumKernel for Cuda { fn forward( &self, @@ -29,11 +40,10 @@ impl super::SumKernel for Cuda { let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - let dims: CudaSlice = self.dev.take_async(inp.shape.concrete().into())?; - let inp_strides: CudaSlice = self.dev.take_async(inp.strides.into())?; - let out_strides: Src::Concrete = - BroadcastStridesTo::::broadcast_strides(&dst, dst.strides()); - let out_strides: CudaSlice = self.dev.take_async(out_strides.into())?; + let (dims, strides) = remove_broadcasted_dims(inp.shape.concrete(), inp.strides); + let num_dims = dims.len(); + let dims: CudaSlice = self.dev.take_async(dims)?; + let inp_strides: CudaSlice = self.dev.take_async(strides)?; let mut storage = 
self.dev.alloc_zeros_async::(dst.num_elements())?; @@ -41,16 +51,18 @@ impl super::SumKernel for Cuda { let virtual_numel = inp.shape.num_elements(); let elems_per_thread = (virtual_numel / physical_numel) as f32; + let chunk_len = physical_numel / dst.num_elements(); + let cfg = LaunchConfig::for_num_elems(physical_numel as u32); let params = ( physical_numel, // const size_t numel, - Src::NUM_DIMS, // const size_t num_dims, + num_dims, // const size_t num_dims, elems_per_thread, // const float elems_per_thread, + chunk_len, // const size_t chunk_len, &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, &inp_strides, // const size_t *inp_strides, - &mut storage, // float *out, - &out_strides, // const size_t *out_strides + &mut storage, // float *out ); unsafe { fwd_fn.launch_async(cfg, params) }?; Ok(CudaArray { diff --git a/src/tensor_ops/sum_to/sum_to.cu b/src/tensor_ops/sum_to/sum_to.cu index 15b05de59..c2aec2aed 100644 --- a/src/tensor_ops/sum_to/sum_to.cu +++ b/src/tensor_ops/sum_to/sum_to.cu @@ -27,29 +27,78 @@ __device__ unsigned int get_unstrided_index( return idx; } -// Accepts pre-broadcasted strides for both input & output. -// So both inp & out are expected to be broadcasted to the same size. +__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { + // Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +// Efficiently computes the sum of each chunk in "data" of size chunk_len, and +// stores the sums in out[i / chunk_len] +__device__ void chunk_sum( + const size_t numel, + const size_t chunk_len, + const float data, + float* out +) { + __shared__ float buf[1024]; + // assumes that threads where i >= numel have already exited + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int block_i = threadIdx.x; + buf[block_i] = data; + + unsigned int chunk_i = i % chunk_len; + unsigned int chunk_start = max((int)(block_i - chunk_i), 0); + unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x); + + size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x); + size_t incr = next_power_of_two(max_chunk_len) >> 1; + + __syncthreads(); + + // Uses sequential addressing as discussed in + // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf + for (; incr > 0; incr >>= 1) { + unsigned int block_i_2 = block_i + incr; + + if (block_i_2 < chunk_end) { + // This is sount because all threads read and write at the same time + buf[block_i] += buf[block_i_2]; + } + + __syncthreads(); + } + + if (block_i == chunk_start) { + atomicAdd(out + i / chunk_len, buf[block_i]); + } +} + +// inp_strides and dims must have broadcasted dimensions removed extern "C" __global__ void sum_to_forward( const size_t numel, const size_t num_dims, const float elems_per_thread, + const size_t chunk_len, const size_t *dims, const float *inp, const size_t *inp_strides, - float *out, - const size_t *out_strides + float *out ) { - unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - if (inp_i >= numel) { + if (i >= numel) { return; } - auto tmp = inp[inp_i]; - - unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides); - unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); - atomicAdd(out + out_i, tmp * elems_per_thread); + unsigned int inp_i = get_strided_index(i, num_dims, dims, 
inp_strides); + chunk_sum(numel, chunk_len, inp[inp_i], out); } // Accepts pre-broadcasted strides for both input & output. From 0161a3a1568504b8b9a77fbbbf5a89f9c46cebc6 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Thu, 19 Jan 2023 23:17:20 -0600 Subject: [PATCH 03/11] fix stride computation for sum_to cuda kernel --- src/tensor_ops/sum_to/cuda_kernel.rs | 30 +++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index 9a8afe9c8..6464bccb7 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -14,9 +14,30 @@ const BWD_FN_NAME: &str = "sum_to_backward"; const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/sum_to.ptx")); -pub fn remove_broadcasted_dims(dims: I, strides: I) -> (Vec, Vec) - where I: IntoIterator +fn permute_with(vec: &mut Vec) { + let mut tmp = vec + .iter() + .enumerate() + .map(|(i, x)| (Ax::as_array().into_iter().any(|x| x == i as isize), *x)) + .collect::>(); + + // requires stable sorting + tmp.sort_by_key(|x| x.0); + + for (v, t) in vec.iter_mut().zip(tmp.iter()) { + *v = t.1; + } +} + +fn get_sum_dims_strides(dims: I, strides: I) -> (Vec, Vec) + where I: IntoIterator, { + let mut dims: Vec = dims.into_iter().collect(); + let mut strides: Vec = strides.into_iter().collect(); + + permute_with::(&mut dims); + permute_with::(&mut strides); + dims .into_iter() .zip(strides.into_iter()) @@ -31,7 +52,7 @@ impl super::SumKernel for Cuda { inp: &Self::Storage, ) -> Result, Self::Err> where - Src: ReduceShapeTo, + Src: ReduceShapeTo { if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) { self.dev @@ -39,8 +60,7 @@ impl super::SumKernel for Cuda { } let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - - let (dims, strides) = remove_broadcasted_dims(inp.shape.concrete(), inp.strides); + let (dims, strides) = get_sum_dims_strides::<_, Ax>(inp.shape.concrete(), inp.strides); let num_dims = dims.len(); let dims: CudaSlice = self.dev.take_async(dims)?; let inp_strides: CudaSlice = self.dev.take_async(strides)?; From a7248fd03d1deb1a098dd493be41faf1eacd6ce1 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Thu, 19 Jan 2023 23:38:40 -0600 Subject: [PATCH 04/11] run cargo fmt; efficiency/readibility changes --- src/optim/adam/cuda_kernel.rs | 13 ++++++------- src/optim/adam/mod.rs | 3 ++- src/tensor_ops/sum_to/cuda_kernel.rs | 26 +++++++++++++------------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/optim/adam/cuda_kernel.rs b/src/optim/adam/cuda_kernel.rs index 9384a9626..c72239520 100644 --- a/src/optim/adam/cuda_kernel.rs +++ b/src/optim/adam/cuda_kernel.rs @@ -1,8 +1,8 @@ -use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig}; -use std::sync::Arc; -use crate::{shapes::Shape, tensor::Cuda}; use super::AdamConfig; use crate::optim::optimizer::*; +use crate::{shapes::Shape, tensor::Cuda}; +use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig}; +use std::sync::Arc; #[repr(C)] struct CudaAdamConfig { @@ -16,7 +16,7 @@ struct CudaAdamConfig { unsafe impl AsKernelParam for CudaAdamConfig {} -fn adam_config_to_cuda(config: &AdamConfig) -> CudaAdamConfig { +fn adam_config_to_cuda(config: &AdamConfig) -> CudaAdamConfig { let (weight_decay_type, weight_decay) = weight_decay_to_cuda(config.weight_decay); CudaAdamConfig { @@ -25,7 +25,7 @@ fn adam_config_to_cuda(config: &AdamConfig) -> CudaAdamConfi beta2: 
config.betas[1], eps: config.eps, weight_decay_type, - weight_decay + weight_decay, } } @@ -48,8 +48,7 @@ impl super::AdamKernel for Cuda { debug_assert_eq!(param.strides, grad.strides); if !self.dev.has_func(MODULE_NAME, FN_NAME) { - self.dev - .load_ptx(PTX_SRC.into(), MODULE_NAME, &[FN_NAME])?; + self.dev.load_ptx(PTX_SRC.into(), MODULE_NAME, &[FN_NAME])?; } let adam_cfg = adam_config_to_cuda(cfg); diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 5701ca715..8d93cc9be 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -137,7 +137,8 @@ impl, E: Dtype> ParamUpdater for Adam< Some(g) => { let m_t = self.moment1.get_or_alloc_mut(p)?; let v_t = self.moment2.get_or_alloc_mut(p)?; - p.device.update(self.t, &self.cfg, &mut p.storage, m_t, v_t, g)?; + p.device + .update(self.t, &self.cfg, &mut p.storage, m_t, v_t, g)?; } } Ok(()) diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index 6464bccb7..d282c44f4 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -14,32 +14,32 @@ const BWD_FN_NAME: &str = "sum_to_backward"; const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/sum_to.ptx")); -fn permute_with(vec: &mut Vec) { - let mut tmp = vec - .iter() - .enumerate() - .map(|(i, x)| (Ax::as_array().into_iter().any(|x| x == i as isize), *x)) - .collect::>(); +fn permute_axes_to_end(vec: &mut Vec) { + let mut tmp = vec.iter().map(|x| (false, *x)).collect::>(); + + for i in Ax::as_array().into_iter() { + tmp[i as usize].0 = true; + } // requires stable sorting tmp.sort_by_key(|x| x.0); for (v, t) in vec.iter_mut().zip(tmp.iter()) { *v = t.1; - } + } } fn get_sum_dims_strides(dims: I, strides: I) -> (Vec, Vec) - where I: IntoIterator, +where + I: IntoIterator, { let mut dims: Vec = dims.into_iter().collect(); let mut strides: Vec = strides.into_iter().collect(); - permute_with::(&mut dims); - permute_with::(&mut strides); + permute_axes_to_end::(&mut dims); + permute_axes_to_end::(&mut strides); - dims - .into_iter() + dims.into_iter() .zip(strides.into_iter()) .filter(|(_, stride)| *stride != 0) .unzip() @@ -52,7 +52,7 @@ impl super::SumKernel for Cuda { inp: &Self::Storage, ) -> Result, Self::Err> where - Src: ReduceShapeTo + Src: ReduceShapeTo, { if !self.dev.has_func(MODULE_NAME, FWD_FN_NAME) { self.dev From 6fe431c68588a9350f91a41bd9aa2e5d6065d782 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 00:48:34 -0600 Subject: [PATCH 05/11] rename funciton; fix bugs in chunk_sum; add test for chunk_sum --- src/tensor_ops/sum_to/cuda_kernel.rs | 6 ++++-- src/tensor_ops/sum_to/mod.rs | 10 ++++++++++ src/tensor_ops/sum_to/sum_to.cu | 7 +++++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index d282c44f4..0bdec8c36 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -29,7 +29,7 @@ fn permute_axes_to_end(vec: &mut Vec) { } } -fn get_sum_dims_strides(dims: I, strides: I) -> (Vec, Vec) +fn reduction_dims_and_strides(dims: I, strides: I) -> (Vec, Vec) where I: IntoIterator, { @@ -60,7 +60,9 @@ impl super::SumKernel for Cuda { } let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - let (dims, strides) = get_sum_dims_strides::<_, Ax>(inp.shape.concrete(), inp.strides); + + let (dims, strides) = + reduction_dims_and_strides::<_, Ax>(inp.shape.concrete(), inp.strides); let 
num_dims = dims.len(); let dims: CudaSlice = self.dev.take_async(dims)?; let inp_strides: CudaSlice = self.dev.take_async(strides)?; diff --git a/src/tensor_ops/sum_to/mod.rs b/src/tensor_ops/sum_to/mod.rs index 4e474b84f..a4721d8c0 100644 --- a/src/tensor_ops/sum_to/mod.rs +++ b/src/tensor_ops/sum_to/mod.rs @@ -125,4 +125,14 @@ mod tests { let g2 = r2.sum().backward(); assert_close(&g.get(&t).array(), &g2.get(&t).array()); } + + #[test] + fn test_sum_chunking() { + let dev: TestDevice = Default::default(); + let t = dev.tensor([[1.0; 100]; 60]); + let r = t.trace().sum::, _>(); + assert_eq!(r.array(), [100.0; 60]); + // let g = r.sum().backward(); + // assert_close(&g.get(&t).array(), &g2.get(&t).array()); + } } diff --git a/src/tensor_ops/sum_to/sum_to.cu b/src/tensor_ops/sum_to/sum_to.cu index c2aec2aed..7249de0fc 100644 --- a/src/tensor_ops/sum_to/sum_to.cu +++ b/src/tensor_ops/sum_to/sum_to.cu @@ -57,6 +57,8 @@ __device__ void chunk_sum( unsigned int chunk_start = max((int)(block_i - chunk_i), 0); unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x); + chunk_i = block_i - chunk_start; + size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x); size_t incr = next_power_of_two(max_chunk_len) >> 1; @@ -67,8 +69,9 @@ __device__ void chunk_sum( for (; incr > 0; incr >>= 1) { unsigned int block_i_2 = block_i + incr; - if (block_i_2 < chunk_end) { - // This is sount because all threads read and write at the same time + if (block_i_2 < chunk_end && chunk_i < incr) { + // This is sound because __syncthreads and the conditions above + // ensure that no data races occur buf[block_i] += buf[block_i_2]; } From 6e79ba5180bd25bb1229e7eb7654d3621d9c9bf9 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 10:57:28 -0600 Subject: [PATCH 06/11] rename funciton; fix bugs in chunk_sum; more tests for sum --- src/lib.rs | 4 ++++ src/tensor_ops/sum_to/mod.rs | 18 +++++++++++++++--- src/tensor_ops/sum_to/sum_to.cu | 2 +- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5bd0343d5..3d9da96c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -210,6 +210,10 @@ pub(crate) mod tests { pub fn assert_close(a: &T, b: &T) { a.assert_close(b, TOLERANCE); } + + pub fn assert_close_with_tolerance(a: &T, b: &T, tolerance: f32) { + a.assert_close(b, tolerance); + } } /// Used to assert things about const generics diff --git a/src/tensor_ops/sum_to/mod.rs b/src/tensor_ops/sum_to/mod.rs index a4721d8c0..69fc951fc 100644 --- a/src/tensor_ops/sum_to/mod.rs +++ b/src/tensor_ops/sum_to/mod.rs @@ -78,7 +78,7 @@ impl, T: Tape> SumTo for Tensor, _>(rand_distr::StandardNormal); + let t2 = t1.clone().broadcast::, _>(); + let r1 = t1.trace().sum::, _>() * 5.0; + let r2 = t2.trace().sum::, _>(); + assert_close_with_tolerance(&r1.array(), &r2.array(), 3e-6); + let g = r1.sum().backward(); + assert_close(&g.get(&t1).array(), &[[5.0; 3]; 4]); + } + #[test] fn test_sum_chunking() { let dev: TestDevice = Default::default(); let t = dev.tensor([[1.0; 100]; 60]); let r = t.trace().sum::, _>(); assert_eq!(r.array(), [100.0; 60]); - // let g = r.sum().backward(); - // assert_close(&g.get(&t).array(), &g2.get(&t).array()); + let g = r.sum().backward(); + assert_close(&g.get(&t).array(), &t.array()); } } diff --git a/src/tensor_ops/sum_to/sum_to.cu b/src/tensor_ops/sum_to/sum_to.cu index 7249de0fc..931908d6d 100644 --- a/src/tensor_ops/sum_to/sum_to.cu +++ b/src/tensor_ops/sum_to/sum_to.cu @@ -101,7 +101,7 @@ extern "C" __global__ void sum_to_forward( 
} unsigned int inp_i = get_strided_index(i, num_dims, dims, inp_strides); - chunk_sum(numel, chunk_len, inp[inp_i], out); + chunk_sum(numel, chunk_len, inp[inp_i] * elems_per_thread, out); } // Accepts pre-broadcasted strides for both input & output. From 88cf72ebd403d0dfb503475131ea0972dfbb7d85 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 11:45:12 -0600 Subject: [PATCH 07/11] simplify, document, and rename permute_for_reductions --- src/lib.rs | 6 ++++- src/tensor_ops/sum_to/cuda_kernel.rs | 38 ++++++++++++---------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3d9da96c9..361d52566 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -211,7 +211,11 @@ pub(crate) mod tests { a.assert_close(b, TOLERANCE); } - pub fn assert_close_with_tolerance(a: &T, b: &T, tolerance: f32) { + pub fn assert_close_with_tolerance( + a: &T, + b: &T, + tolerance: f32, + ) { a.assert_close(b, tolerance); } } diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index 0bdec8c36..b3744af76 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -14,33 +14,28 @@ const BWD_FN_NAME: &str = "sum_to_backward"; const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/sum_to.ptx")); -fn permute_axes_to_end(vec: &mut Vec) { - let mut tmp = vec.iter().map(|x| (false, *x)).collect::>(); +/// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions +/// so that a cuda kernel called for each physical element of the input tensor will place elements +/// to be reduced with each other next to each other in memory. +fn permute_for_reductions(dims: I, strides: I) -> (Vec, Vec) +where + I: IntoIterator, +{ + let mut tmp = dims + .into_iter() + .zip(strides.into_iter()) + .map(|x| (false, x)) + .collect::>(); for i in Ax::as_array().into_iter() { tmp[i as usize].0 = true; } - // requires stable sorting + // requires stable sorting to keep non-summed axes in the correct order tmp.sort_by_key(|x| x.0); - for (v, t) in vec.iter_mut().zip(tmp.iter()) { - *v = t.1; - } -} - -fn reduction_dims_and_strides(dims: I, strides: I) -> (Vec, Vec) -where - I: IntoIterator, -{ - let mut dims: Vec = dims.into_iter().collect(); - let mut strides: Vec = strides.into_iter().collect(); - - permute_axes_to_end::(&mut dims); - permute_axes_to_end::(&mut strides); - - dims.into_iter() - .zip(strides.into_iter()) + tmp.into_iter() + .map(|(_, x)| x) .filter(|(_, stride)| *stride != 0) .unzip() } @@ -61,8 +56,7 @@ impl super::SumKernel for Cuda { let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - let (dims, strides) = - reduction_dims_and_strides::<_, Ax>(inp.shape.concrete(), inp.strides); + let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); let num_dims = dims.len(); let dims: CudaSlice = self.dev.take_async(dims)?; let inp_strides: CudaSlice = self.dev.take_async(strides)?; From 6cc627870b72c655832ce2125057623f9cc98d31 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 12:19:18 -0600 Subject: [PATCH 08/11] move permute_for_reductions; optimize max_to cuda kernel --- src/tensor_ops/internal_reshapes.rs | 31 ++++++++++ src/tensor_ops/max_to/cuda_kernel.rs | 34 +++++----- src/tensor_ops/max_to/max_to.cu | 92 ++++++++++++++++++++++++---- src/tensor_ops/max_to/mod.rs | 2 +- src/tensor_ops/mod.rs | 1 + src/tensor_ops/reductions.txt | 31 ---------- 
src/tensor_ops/sum_to/cuda_kernel.rs | 28 +-------- 7 files changed, 134 insertions(+), 85 deletions(-) create mode 100644 src/tensor_ops/internal_reshapes.rs delete mode 100644 src/tensor_ops/reductions.txt diff --git a/src/tensor_ops/internal_reshapes.rs b/src/tensor_ops/internal_reshapes.rs new file mode 100644 index 000000000..632485443 --- /dev/null +++ b/src/tensor_ops/internal_reshapes.rs @@ -0,0 +1,31 @@ +#[cfg(feature = "cuda")] +use crate::prelude::Axes; +#[cfg(feature = "cuda")] +use std::vec::Vec; + +/// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions +/// so that a cuda kernel called for each physical element of the input tensor will place elements +/// to be reduced with each other next to each other in memory. +#[cfg(feature = "cuda")] +pub(super) fn permute_for_reductions(dims: I, strides: I) -> (Vec, Vec) +where + I: IntoIterator, +{ + let mut tmp = dims + .into_iter() + .zip(strides.into_iter()) + .map(|x| (false, x)) + .collect::>(); + + for i in Ax::as_array().into_iter() { + tmp[i as usize].0 = true; + } + + // requires stable sorting to keep non-summed axes in the correct order + tmp.sort_by_key(|x| x.0); + + tmp.into_iter() + .map(|(_, x)| x) + .filter(|(_, stride)| *stride != 0) + .unzip() +} diff --git a/src/tensor_ops/max_to/cuda_kernel.rs b/src/tensor_ops/max_to/cuda_kernel.rs index c628d2605..e785f0bf6 100644 --- a/src/tensor_ops/max_to/cuda_kernel.rs +++ b/src/tensor_ops/max_to/cuda_kernel.rs @@ -2,6 +2,7 @@ use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; +use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; @@ -38,21 +39,22 @@ impl super::MaxReduceKernel for Cuda { let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - let dims: CudaSlice = self.dev.take_async(inp.shape.concrete().into())?; - let inp_strides: CudaSlice = self.dev.take_async(inp.strides.into())?; - let out_strides = BroadcastStridesTo::::broadcast_strides(&dst, dst.strides()); - let out_strides: CudaSlice = self.dev.take_async(out_strides.into())?; + let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); + let dims: CudaSlice = self.dev.take_async(dims)?; + let inp_strides: CudaSlice = self.dev.take_async(strides)?; + + let physical_numel = inp.data.len(); + let chunk_len = physical_numel / dst.num_elements(); - let inp_numel = inp.shape.num_elements(); - let cfg = LaunchConfig::for_num_elems(inp_numel as u32); + let cfg = LaunchConfig::for_num_elems(physical_numel as u32); let params = ( - inp_numel, // size_t numel, - Src::NUM_DIMS, // size_t num_dims, + physical_numel, // const size_t numel, + dims.len(), // const size_t num_dims, + chunk_len, // const size_t chunk_len, &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, &inp_strides, // const size_t *inp_strides, - &mut storage, // float *out, - &out_strides, // const size_t *out_strides + &mut storage, // float *out ); unsafe { fwd_fn.launch_async(cfg, params) }?; Ok(CudaArray { @@ -80,11 +82,15 @@ impl super::MaxReduceKernel for Cuda { BroadcastStridesTo::::broadcast_strides(&grad_out.shape, grad_out.strides); let out_strides: CudaSlice = self.dev.take_async(out_strides.into())?; - let inp_numel = grad_inp.shape.num_elements(); - let cfg = LaunchConfig::for_num_elems(inp_numel as u32); + let physical_numel = grad_inp.data.len(); + let virtual_numel = grad_inp.shape.num_elements(); + let elems_per_thread = 
(virtual_numel / physical_numel) as f32; + + let cfg = LaunchConfig::for_num_elems(physical_numel as u32); let params = ( - inp_numel, // size_t numel, - Src::NUM_DIMS, // size_t num_dims, + physical_numel, // const size_t numel, + Src::NUM_DIMS, // const size_t num_dims, + elems_per_thread as f32, // const float elems_per_thread, &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, Arc::make_mut(&mut grad_inp.data), // float *grad_inp, diff --git a/src/tensor_ops/max_to/max_to.cu b/src/tensor_ops/max_to/max_to.cu index 4fd084c87..6d528fdfd 100644 --- a/src/tensor_ops/max_to/max_to.cu +++ b/src/tensor_ops/max_to/max_to.cu @@ -23,6 +23,20 @@ __device__ unsigned int get_strided_index( return strided_i; } +__device__ unsigned int get_unstrided_index( + const unsigned int strided_i, + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + unsigned int idx = 0; + for (unsigned int d = 0; d < num_dims; d++) { + idx *= dims[d]; + idx += strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d]; + } + return idx; +} + extern "C" __global__ void fill_with(float *buf, float value, const size_t numel) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; if (i >= numel) { @@ -31,16 +45,71 @@ extern "C" __global__ void fill_with(float *buf, float value, const size_t numel buf[i] = value; } +__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { + // Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} + +// Efficiently computes the max of each chunk in "data" of size chunk_len, and +// stores the maximums in out[i / chunk_len] +__device__ void chunk_max( + const size_t numel, + const size_t chunk_len, + const float data, + float* out +) { + __shared__ float buf[1024]; + // assumes that threads where i >= numel have already exited + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int block_i = threadIdx.x; + buf[block_i] = data; + + unsigned int chunk_i = i % chunk_len; + unsigned int chunk_start = max((int)(block_i - chunk_i), 0); + unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x); + + chunk_i = block_i - chunk_start; + + size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x); + size_t incr = next_power_of_two(max_chunk_len) >> 1; + + __syncthreads(); + + // Uses sequential addressing as discussed in + // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf + for (; incr > 0; incr >>= 1) { + unsigned int block_i_2 = block_i + incr; + + if (block_i_2 < chunk_end && chunk_i < incr) { + // This is sound because __syncthreads and the conditions above + // ensure that no data races occur + buf[block_i] = fmaxf(buf[block_i], buf[block_i_2]); + } + + __syncthreads(); + } + + if (block_i == chunk_start) { + atomicMaxf(out + i / chunk_len, buf[block_i]); + } +} // Accepts pre-broadcasted strides for both input & output. // So both inp & out are expected to be broadcasted to the same size. 
extern "C" __global__ void max_to_forward( const size_t numel, const size_t num_dims, + const size_t chunk_len, const size_t *dims, const float *inp, const size_t *inp_strides, - float *out, - const size_t *out_strides + float *out ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -48,10 +117,8 @@ extern "C" __global__ void max_to_forward( return; } - unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides); - unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides); - - atomicMaxf(out + out_strided_i, inp[inp_strided_i]); + unsigned int inp_i = get_strided_index(i, num_dims, dims, inp_strides); + chunk_max(numel, chunk_len, inp[inp_i], out); } // Accepts pre-broadcasted strides for both input & output. @@ -59,6 +126,7 @@ extern "C" __global__ void max_to_forward( extern "C" __global__ void max_to_backward( const size_t numel, const size_t num_dims, + const float elems_per_thread, const size_t *dims, const float *inp, float *grad_inp, @@ -67,15 +135,15 @@ extern "C" __global__ void max_to_backward( const float *grad_out, const size_t *out_strides ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { + if (inp_i >= numel) { return; } - unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides); - unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides); + unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides); + unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); - auto tmp = inp[inp_strided_i] == out[out_strided_i] ? grad_out[out_strided_i] : 0.0; - atomicAdd(grad_inp + inp_strided_i, tmp); + auto tmp = inp[inp_i] == out[out_i] ? 
grad_out[out_i] : 0.0; + grad_inp[inp_i] += tmp * elems_per_thread; } diff --git a/src/tensor_ops/max_to/mod.rs b/src/tensor_ops/max_to/mod.rs index 929193d98..0e849c403 100644 --- a/src/tensor_ops/max_to/mod.rs +++ b/src/tensor_ops/max_to/mod.rs @@ -87,7 +87,7 @@ mod tests { use crate::tests::{assert_close, TestDevice}; #[test] - fn test_valids_max_axis() { + fn test_max_valid_axes() { let dev: TestDevice = Default::default(); let _ = dev.zeros::>().max::(); let _ = dev.zeros::>().max::, _>(); diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index df6ae33b0..49d18cf18 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -182,6 +182,7 @@ pub(crate) mod cpu_kernels; #[cfg(feature = "cuda")] pub(crate) mod cuda_kernels; pub(crate) mod ops; +mod internal_reshapes; pub use abs::abs; pub use add::{add, TryAdd}; diff --git a/src/tensor_ops/reductions.txt b/src/tensor_ops/reductions.txt deleted file mode 100644 index 88ee2e5bf..000000000 --- a/src/tensor_ops/reductions.txt +++ /dev/null @@ -1,31 +0,0 @@ -have threads get elements in contiguous order and store into shared memory buffer - -segmentSum(i, block_i, segment_size, buf) - segment_i = i % segment_len - segment_start = i - segment_i - segment_end = segment_start + segment_size - - if block_i < segment_i - segment_idx = block_i - segment_len = segment_size - (segment_i - block_i) - else if segment_end > BLOCK_SIZE - segment_idx = segment_i - segment_len = BLOCK_SIZE - segment_start - else - segment_idx = segment_i - segment_len = segment_size - - stop = min(BLOCK_SIZE, segment_size) - segment = buf + block_i - segment_idx - - for (unsigned int s=1; s < stop; s *= 2) { - int index = 2 * s * tid; - - if (index < segmnet_len) { - segment[index] += segment[index + s]; - } - __syncthreads(); - } - - if segment_idx == 0 - out[i / segment_size] += segment[0] diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index b3744af76..ba41127c7 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -2,11 +2,11 @@ use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; +use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; use std::sync::Arc; -use std::vec::Vec; const MODULE_NAME: &str = "sum_to"; const FWD_FN_NAME: &str = "sum_to_forward"; @@ -14,32 +14,6 @@ const BWD_FN_NAME: &str = "sum_to_backward"; const ALL_FN_NAMES: [&str; 2] = [FWD_FN_NAME, BWD_FN_NAME]; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/sum_to.ptx")); -/// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions -/// so that a cuda kernel called for each physical element of the input tensor will place elements -/// to be reduced with each other next to each other in memory. 
-fn permute_for_reductions(dims: I, strides: I) -> (Vec, Vec) -where - I: IntoIterator, -{ - let mut tmp = dims - .into_iter() - .zip(strides.into_iter()) - .map(|x| (false, x)) - .collect::>(); - - for i in Ax::as_array().into_iter() { - tmp[i as usize].0 = true; - } - - // requires stable sorting to keep non-summed axes in the correct order - tmp.sort_by_key(|x| x.0); - - tmp.into_iter() - .map(|(_, x)| x) - .filter(|(_, stride)| *stride != 0) - .unzip() -} - impl super::SumKernel for Cuda { fn forward( &self, From 0cd928590cb266205869458ed5f7ee9f83fee2cd Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 12:31:01 -0600 Subject: [PATCH 09/11] readability tweaks --- src/tensor_ops/max_to/cuda_kernel.rs | 6 +++--- src/tensor_ops/max_to/max_to.cu | 14 +++++++------- src/tensor_ops/sum_to/cuda_kernel.rs | 6 +++--- src/tensor_ops/sum_to/sum_to.cu | 12 ++++++------ 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/tensor_ops/max_to/cuda_kernel.rs b/src/tensor_ops/max_to/cuda_kernel.rs index e785f0bf6..a221dcb3e 100644 --- a/src/tensor_ops/max_to/cuda_kernel.rs +++ b/src/tensor_ops/max_to/cuda_kernel.rs @@ -41,7 +41,7 @@ impl super::MaxReduceKernel for Cuda { let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); let dims: CudaSlice = self.dev.take_async(dims)?; - let inp_strides: CudaSlice = self.dev.take_async(strides)?; + let strides: CudaSlice = self.dev.take_async(strides)?; let physical_numel = inp.data.len(); let chunk_len = physical_numel / dst.num_elements(); @@ -51,9 +51,9 @@ impl super::MaxReduceKernel for Cuda { physical_numel, // const size_t numel, dims.len(), // const size_t num_dims, chunk_len, // const size_t chunk_len, - &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, - &inp_strides, // const size_t *inp_strides, + &dims, // const size_t *dims, + &strides, // const size_t *strides, &mut storage, // float *out ); unsafe { fwd_fn.launch_async(cfg, params) }?; diff --git a/src/tensor_ops/max_to/max_to.cu b/src/tensor_ops/max_to/max_to.cu index 6d528fdfd..9cef77960 100644 --- a/src/tensor_ops/max_to/max_to.cu +++ b/src/tensor_ops/max_to/max_to.cu @@ -45,14 +45,13 @@ extern "C" __global__ void fill_with(float *buf, float value, const size_t numel buf[i] = value; } +// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 __device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { - // Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 v--; v |= v >> 1; v |= v >> 2; v |= v >> 4; v |= v >> 8; - v |= v >> 16; v++; return v; } @@ -100,15 +99,16 @@ __device__ void chunk_max( atomicMaxf(out + i / chunk_len, buf[block_i]); } } -// Accepts pre-broadcasted strides for both input & output. -// So both inp & out are expected to be broadcasted to the same size. 
+ +// strides and dims specify how to index inp to put all summed elements next to +// each other, and chunk_len is len(inp) / len(out) extern "C" __global__ void max_to_forward( const size_t numel, const size_t num_dims, const size_t chunk_len, - const size_t *dims, const float *inp, - const size_t *inp_strides, + const size_t *dims, + const size_t *strides, float *out ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -117,7 +117,7 @@ extern "C" __global__ void max_to_forward( return; } - unsigned int inp_i = get_strided_index(i, num_dims, dims, inp_strides); + unsigned int inp_i = get_strided_index(i, num_dims, dims, strides); chunk_max(numel, chunk_len, inp[inp_i], out); } diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index ba41127c7..f4ab4160b 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -33,7 +33,7 @@ impl super::SumKernel for Cuda { let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); let num_dims = dims.len(); let dims: CudaSlice = self.dev.take_async(dims)?; - let inp_strides: CudaSlice = self.dev.take_async(strides)?; + let strides: CudaSlice = self.dev.take_async(strides)?; let mut storage = self.dev.alloc_zeros_async::(dst.num_elements())?; @@ -49,9 +49,9 @@ impl super::SumKernel for Cuda { num_dims, // const size_t num_dims, elems_per_thread, // const float elems_per_thread, chunk_len, // const size_t chunk_len, - &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, - &inp_strides, // const size_t *inp_strides, + &dims, // const size_t *dims, + &strides, // const size_t *strides, &mut storage, // float *out ); unsafe { fwd_fn.launch_async(cfg, params) }?; diff --git a/src/tensor_ops/sum_to/sum_to.cu b/src/tensor_ops/sum_to/sum_to.cu index 931908d6d..dc9d539c5 100644 --- a/src/tensor_ops/sum_to/sum_to.cu +++ b/src/tensor_ops/sum_to/sum_to.cu @@ -27,14 +27,13 @@ __device__ unsigned int get_unstrided_index( return idx; } +// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 __device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { - // Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 v--; v |= v >> 1; v |= v >> 2; v |= v >> 4; v |= v >> 8; - v |= v >> 16; v++; return v; } @@ -83,15 +82,16 @@ __device__ void chunk_sum( } } -// inp_strides and dims must have broadcasted dimensions removed +// strides and dims specify how to index inp to put all summed elements next to +// each other, and chunk_len is len(inp) / len(out) extern "C" __global__ void sum_to_forward( const size_t numel, const size_t num_dims, const float elems_per_thread, const size_t chunk_len, - const size_t *dims, const float *inp, - const size_t *inp_strides, + const size_t *dims, + const size_t *strides, float *out ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -100,7 +100,7 @@ extern "C" __global__ void sum_to_forward( return; } - unsigned int inp_i = get_strided_index(i, num_dims, dims, inp_strides); + unsigned int inp_i = get_strided_index(i, num_dims, dims, strides); chunk_sum(numel, chunk_len, inp[inp_i] * elems_per_thread, out); } From fc8c55cc194923fcea060a309b966e4b351a580d Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 13:16:16 -0600 Subject: [PATCH 10/11] implement min_to cuda kernel --- src/tensor_ops/min_to/cuda_kernel.rs | 38 +++++----- src/tensor_ops/min_to/min_to.cu | 100 ++++++++++++++++++++++----- 
src/tensor_ops/min_to/mod.rs | 2 +- 3 files changed, 107 insertions(+), 33 deletions(-) diff --git a/src/tensor_ops/min_to/cuda_kernel.rs b/src/tensor_ops/min_to/cuda_kernel.rs index 4b6222c29..b828d14b5 100644 --- a/src/tensor_ops/min_to/cuda_kernel.rs +++ b/src/tensor_ops/min_to/cuda_kernel.rs @@ -2,6 +2,7 @@ use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; +use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; @@ -38,21 +39,22 @@ impl super::MinReduceKernel for Cuda { let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap(); - let dims: CudaSlice = self.dev.take_async(inp.shape.concrete().into())?; - let inp_strides: CudaSlice = self.dev.take_async(inp.strides.into())?; - let out_strides = BroadcastStridesTo::::broadcast_strides(&dst, dst.strides()); - let out_strides: CudaSlice = self.dev.take_async(out_strides.into())?; + let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides); + let dims: CudaSlice = self.dev.take_async(dims)?; + let strides: CudaSlice = self.dev.take_async(strides)?; + + let physical_numel = inp.data.len(); + let chunk_len = physical_numel / dst.num_elements(); - let inp_numel = inp.shape.num_elements(); - let cfg = LaunchConfig::for_num_elems(inp_numel as u32); + let cfg = LaunchConfig::for_num_elems(physical_numel as u32); let params = ( - inp_numel, // size_t numel, - Src::NUM_DIMS, // size_t num_dims, - &dims, // const size_t *dims, + physical_numel, // const size_t numel, + dims.len(), // const size_t num_dims, + chunk_len, // const size_t chunk_len, inp.data.as_ref(), // const float *inp, - &inp_strides, // const size_t *inp_strides, - &mut storage, // float *out, - &out_strides, // const size_t *out_strides + &dims, // const size_t *dims, + &strides, // const size_t *strides, + &mut storage, // float *out ); unsafe { fwd_fn.launch_async(cfg, params) }?; Ok(CudaArray { @@ -80,11 +82,15 @@ impl super::MinReduceKernel for Cuda { BroadcastStridesTo::::broadcast_strides(&grad_out.shape, grad_out.strides); let out_strides: CudaSlice = self.dev.take_async(out_strides.into())?; - let inp_numel = grad_inp.shape.num_elements(); - let cfg = LaunchConfig::for_num_elems(inp_numel as u32); + let physical_numel = grad_inp.data.len(); + let virtual_numel = grad_inp.shape.num_elements(); + let elems_per_thread = (virtual_numel / physical_numel) as f32; + + let cfg = LaunchConfig::for_num_elems(physical_numel as u32); let params = ( - inp_numel, // size_t numel, - Src::NUM_DIMS, // size_t num_dims, + physical_numel, // const size_t numel, + Src::NUM_DIMS, // const size_t num_dims, + elems_per_thread as f32, // const float elems_per_thread, &dims, // const size_t *dims, inp.data.as_ref(), // const float *inp, Arc::make_mut(&mut grad_inp.data), // float *grad_inp, diff --git a/src/tensor_ops/min_to/min_to.cu b/src/tensor_ops/min_to/min_to.cu index 512023e75..44b73b585 100644 --- a/src/tensor_ops/min_to/min_to.cu +++ b/src/tensor_ops/min_to/min_to.cu @@ -23,6 +23,20 @@ __device__ unsigned int get_strided_index( return strided_i; } +__device__ unsigned int get_unstrided_index( + const unsigned int strided_i, + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + unsigned int idx = 0; + for (unsigned int d = 0; d < num_dims; d++) { + idx *= dims[d]; + idx += strides[d] == 0 ? 
0 : (strided_i / strides[d]) % dims[d]; + } + return idx; +} + extern "C" __global__ void fill_with(float *buf, float value, const size_t numel) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; if (i >= numel) { @@ -31,16 +45,71 @@ extern "C" __global__ void fill_with(float *buf, float value, const size_t numel buf[i] = value; } -// Accepts pre-broadcasted strides for both input & output. -// So both inp & out are expected to be broadcasted to the same size. +// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 +__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v++; + return v; +} + +// Efficiently computes the min of each chunk in "data" of size chunk_len, and +// stores the minimums in out[i / chunk_len] +__device__ void chunk_min( + const size_t numel, + const size_t chunk_len, + const float data, + float* out +) { + __shared__ float buf[1024]; + // assumes that threads where i >= numel have already exited + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int block_i = threadIdx.x; + buf[block_i] = data; + + unsigned int chunk_i = i % chunk_len; + unsigned int chunk_start = max((int)(block_i - chunk_i), 0); + unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x); + + chunk_i = block_i - chunk_start; + + size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x); + size_t incr = next_power_of_two(max_chunk_len) >> 1; + + __syncthreads(); + + // Uses sequential addressing as discussed in + // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf + for (; incr > 0; incr >>= 1) { + unsigned int block_i_2 = block_i + incr; + + if (block_i_2 < chunk_end && chunk_i < incr) { + // This is sound because __syncthreads and the conditions above + // ensure that no data races occur + buf[block_i] = fminf(buf[block_i], buf[block_i_2]); + } + + __syncthreads(); + } + + if (block_i == chunk_start) { + atomicMinf(out + i / chunk_len, buf[block_i]); + } +} + +// strides and dims specify how to index inp to put all summed elements next to +// each other, and chunk_len is len(inp) / len(out) extern "C" __global__ void min_to_forward( const size_t numel, const size_t num_dims, - const size_t *dims, + const size_t chunk_len, const float *inp, - const size_t *inp_strides, - float *out, - const size_t *out_strides + const size_t *dims, + const size_t *strides, + float *out ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -48,10 +117,8 @@ extern "C" __global__ void min_to_forward( return; } - unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides); - unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides); - - atomicMinf(out + out_strided_i, inp[inp_strided_i]); + unsigned int inp_i = get_strided_index(i, num_dims, dims, strides); + chunk_min(numel, chunk_len, inp[inp_i], out); } // Accepts pre-broadcasted strides for both input & output. 
@@ -59,6 +126,7 @@ extern "C" __global__ void min_to_forward( extern "C" __global__ void min_to_backward( const size_t numel, const size_t num_dims, + const float elems_per_thread, const size_t *dims, const float *inp, float *grad_inp, @@ -67,15 +135,15 @@ extern "C" __global__ void min_to_backward( const float *grad_out, const size_t *out_strides ) { - unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= numel) { + if (inp_i >= numel) { return; } - unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides); - unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides); + unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides); + unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides); - auto tmp = inp[inp_strided_i] == out[out_strided_i] ? grad_out[out_strided_i] : 0.0; - atomicAdd(grad_inp + inp_strided_i, tmp); + auto tmp = inp[inp_i] == out[out_i] ? grad_out[out_i] : 0.0; + grad_inp[inp_i] += tmp * elems_per_thread; } diff --git a/src/tensor_ops/min_to/mod.rs b/src/tensor_ops/min_to/mod.rs index 707934842..82951d593 100644 --- a/src/tensor_ops/min_to/mod.rs +++ b/src/tensor_ops/min_to/mod.rs @@ -87,7 +87,7 @@ mod tests { use crate::tests::{assert_close, TestDevice}; #[test] - fn test_valids_min_axis() { + fn test_min_valid_axes() { let dev: TestDevice = Default::default(); let _ = dev.zeros::>().min::(); let _ = dev.zeros::>().min::, _>(); From 025d883961f0a534b3109323e9b5a616ca836cf6 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Fri, 20 Jan 2023 13:17:22 -0600 Subject: [PATCH 11/11] run cargo fmt --- src/tensor_ops/max_to/cuda_kernel.rs | 2 +- src/tensor_ops/min_to/cuda_kernel.rs | 2 +- src/tensor_ops/mod.rs | 2 +- src/tensor_ops/sum_to/cuda_kernel.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tensor_ops/max_to/cuda_kernel.rs b/src/tensor_ops/max_to/cuda_kernel.rs index a221dcb3e..ea0e6cae5 100644 --- a/src/tensor_ops/max_to/cuda_kernel.rs +++ b/src/tensor_ops/max_to/cuda_kernel.rs @@ -1,8 +1,8 @@ +use crate::tensor_ops::internal_reshapes::permute_for_reductions; use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; -use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; diff --git a/src/tensor_ops/min_to/cuda_kernel.rs b/src/tensor_ops/min_to/cuda_kernel.rs index b828d14b5..a38d2a614 100644 --- a/src/tensor_ops/min_to/cuda_kernel.rs +++ b/src/tensor_ops/min_to/cuda_kernel.rs @@ -1,8 +1,8 @@ +use crate::tensor_ops::internal_reshapes::permute_for_reductions; use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; -use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig}; diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 49d18cf18..d597ef112 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -181,8 +181,8 @@ mod var_to; pub(crate) mod cpu_kernels; #[cfg(feature = "cuda")] pub(crate) mod cuda_kernels; -pub(crate) mod ops; mod internal_reshapes; +pub(crate) mod ops; pub use abs::abs; pub use add::{add, TryAdd}; diff --git a/src/tensor_ops/sum_to/cuda_kernel.rs b/src/tensor_ops/sum_to/cuda_kernel.rs index f4ab4160b..0c2f3b413 100644 --- a/src/tensor_ops/sum_to/cuda_kernel.rs +++ b/src/tensor_ops/sum_to/cuda_kernel.rs @@ -1,8 +1,8 @@ +use 
crate::tensor_ops::internal_reshapes::permute_for_reductions; use crate::{ shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape}, tensor::cuda::{Cuda, CudaArray}, }; -use crate::tensor_ops::internal_reshapes::permute_for_reductions; use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
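
A standalone sketch of the axis permutation these reduction kernels depend on: permute_for_reductions moves every reduced axis to the end of the dims/strides lists and drops broadcast (stride-0) dimensions, so that all inputs of one output value land at consecutive kernel indices, in chunks of length chunk_len = len(inp) / len(out). The version below is illustrative only: it takes the reduced axes as a plain slice rather than the crate's Axes type parameter, and the shapes in main are made-up examples, not values taken from the patches.

/// Simplified stand-in for the patch's permute_for_reductions, with the
/// reduced axes passed as a plain slice (an assumption for this sketch)
/// instead of dfdx's Axes type parameter.
fn permute_for_reductions(
    dims: &[usize],
    strides: &[usize],
    reduced_axes: &[usize],
) -> (Vec<usize>, Vec<usize>) {
    let mut tmp: Vec<(bool, (usize, usize))> = dims
        .iter()
        .zip(strides.iter())
        .enumerate()
        .map(|(i, (&d, &s))| (reduced_axes.contains(&i), (d, s)))
        .collect();

    // Stable sort: kept axes stay at the front in their original order,
    // reduced axes move to the back in their original relative order.
    tmp.sort_by_key(|x| x.0);

    tmp.into_iter()
        .map(|(_, x)| x)
        .filter(|(_, stride)| *stride != 0) // drop broadcast (stride 0) dims
        .unzip()
}

fn main() {
    // Reducing axis 1 of a contiguous (2, 3, 4) tensor: the reduced axis is
    // moved to the end, so the 3 inputs of each of the 8 outputs occupy one
    // chunk of length chunk_len = 24 / 8 = 3 in the kernel's index space.
    let (dims, strides) = permute_for_reductions(&[2, 3, 4], &[12, 4, 1], &[1]);
    assert_eq!(dims, vec![2, 4, 3]);
    assert_eq!(strides, vec![12, 1, 4]);

    // A broadcast dimension (stride 0) is dropped entirely, which is why the
    // kernels launch over the physical element count and scale sums by
    // elems_per_thread instead of visiting every virtual element.
    let (dims, strides) = permute_for_reductions(&[2, 3], &[0, 1], &[0]);
    assert_eq!(dims, vec![3]);
    assert_eq!(strides, vec![1]);
}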