Efficient cuda kernels for reductions #382
Merged: 12 commits, Jan 23, 2023
8 changes: 8 additions & 0 deletions src/lib.rs
@@ -210,6 +210,14 @@ pub(crate) mod tests {
pub fn assert_close<T: AssertClose + std::fmt::Debug>(a: &T, b: &T) {
a.assert_close(b, TOLERANCE);
}

pub fn assert_close_with_tolerance<T: AssertClose + std::fmt::Debug>(
a: &T,
b: &T,
tolerance: f32,
) {
a.assert_close(b, tolerance);
}
}

/// Used to assert things about const generics
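The new assert_close_with_tolerance helper behaves like assert_close but lets a test pick its own tolerance instead of the crate-wide TOLERANCE. A minimal test-style sketch, with values invented purely for illustration and assuming AssertClose is implemented for plain f32 arrays:

    // Hypothetical values, just to show the call shape.
    let a = [1.0f32, 2.0000, 3.0];
    let b = [1.0f32, 2.0004, 3.0];
    assert_close_with_tolerance(&a, &b, 1e-3); // passes under the looser 1e-3 tolerance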
31 changes: 31 additions & 0 deletions src/tensor_ops/internal_reshapes.rs
@@ -0,0 +1,31 @@
#[cfg(feature = "cuda")]
use crate::prelude::Axes;
#[cfg(feature = "cuda")]
use std::vec::Vec;

/// Moves all axes in Ax to the end of dims and strides and removes broadcasted dimensions,
/// so that a cuda kernel launched once per physical element of the input tensor sees the
/// elements it reduces together laid out next to each other in memory.
#[cfg(feature = "cuda")]
pub(super) fn permute_for_reductions<I, Ax: Axes>(dims: I, strides: I) -> (Vec<usize>, Vec<usize>)
where
I: IntoIterator<Item = usize>,
{
let mut tmp = dims
.into_iter()
.zip(strides.into_iter())
.map(|x| (false, x))
.collect::<Vec<_>>();

for i in Ax::as_array().into_iter() {
tmp[i as usize].0 = true;
}

// requires stable sorting to keep non-summed axes in the correct order
tmp.sort_by_key(|x| x.0);

tmp.into_iter()
.map(|(_, x)| x)
.filter(|(_, stride)| *stride != 0)
.unzip()
}
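For intuition about the permutation, here is a test-style sketch on an assumed shape (the shape, strides, and Axis<1> below are illustrative, not taken from the PR):

    // Reducing a (2, 3, 4) tensor with contiguous strides (12, 4, 1) over its middle axis:
    // the reduced axis moves to the end of both vectors, and any stride-0 (broadcast)
    // axis would be filtered out of both.
    let (dims, strides) = permute_for_reductions::<_, Axis<1>>([2, 3, 4], [12, 4, 1]);
    assert_eq!(dims, vec![2, 4, 3]);
    assert_eq!(strides, vec![12, 1, 4]);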
38 changes: 22 additions & 16 deletions src/tensor_ops/max_to/cuda_kernel.rs
@@ -1,3 +1,4 @@
use crate::tensor_ops::internal_reshapes::permute_for_reductions;
use crate::{
shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape},
tensor::cuda::{Cuda, CudaArray},
@@ -38,21 +39,22 @@ impl super::MaxReduceKernel<f32> for Cuda {

let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?;
let out_strides = BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;
let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides);
let dims: CudaSlice<usize> = self.dev.take_async(dims)?;
let strides: CudaSlice<usize> = self.dev.take_async(strides)?;

let physical_numel = inp.data.len();
let chunk_len = physical_numel / dst.num_elements();

let inp_numel = inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
physical_numel, // const size_t numel,
dims.len(), // const size_t num_dims,
chunk_len, // const size_t chunk_len,
inp.data.as_ref(), // const float *inp,
&inp_strides, // const size_t *inp_strides,
&mut storage, // float *out,
&out_strides, // const size_t *out_strides
&dims, // const size_t *dims,
&strides, // const size_t *strides,
&mut storage, // float *out
);
unsafe { fwd_fn.launch_async(cfg, params) }?;
Ok(CudaArray {
@@ -80,11 +82,15 @@ impl super::MaxReduceKernel<f32> for Cuda {
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&grad_out.shape, grad_out.strides);
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = grad_inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let physical_numel = grad_inp.data.len();
let virtual_numel = grad_inp.shape.num_elements();
let elems_per_thread = (virtual_numel / physical_numel) as f32;

let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
physical_numel, // const size_t numel,
Src::NUM_DIMS, // const size_t num_dims,
elems_per_thread as f32, // const float elems_per_thread,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
Arc::make_mut(&mut grad_inp.data), // float *grad_inp,
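The launch bookkeeping above boils down to two ratios; the shapes here are assumed purely for illustration. Forward: for a contiguous (2, 3, 4) input reduced over its middle axis, physical_numel is 24 and the output has 8 elements, so chunk_len = 3 and each group of three neighboring (post-permutation) elements folds into one output slot. Backward: if the input had instead been broadcast from (2, 4) up to (2, 3, 4), only 8 physical elements back 24 virtual ones, so elems_per_thread = 3.0 scales the gradient each physical element receives.

    // Assumed shapes, for illustration only.
    let physical_numel = 2 * 3 * 4;           // inp.data.len() for a contiguous (2, 3, 4)
    let chunk_len = physical_numel / (2 * 4); // 24 / 8 == 3 elements reduced per output
    let elems_per_thread = (24 / 8) as f32;   // broadcast case: 24 virtual / 8 physical == 3.0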
100 changes: 84 additions & 16 deletions src/tensor_ops/max_to/max_to.cu
@@ -23,6 +23,20 @@ __device__ unsigned int get_strided_index(
return strided_i;
}

__device__ unsigned int get_unstrided_index(
const unsigned int strided_i,
const size_t num_dims,
const size_t *dims,
const size_t *strides
) {
unsigned int idx = 0;
for (unsigned int d = 0; d < num_dims; d++) {
idx *= dims[d];
idx += strides[d] == 0 ? 0 : (strided_i / strides[d]) % dims[d];
}
return idx;
}
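// Illustration (assumed values, not part of the kernel): get_unstrided_index inverts a
// physical offset back to the logical row-major index, contributing 0 for broadcast
// (stride == 0) dimensions. With dims = {2, 3} and strides = {0, 1} (dim 0 broadcast),
// strided_i = 2 maps to logical index 2; with contiguous strides {3, 1} the mapping is
// the identity.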

extern "C" __global__ void fill_with(float *buf, float value, const size_t numel) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
@@ -31,34 +45,88 @@ extern "C" __global__ void fill_with(float *buf, float value, const size_t numel
buf[i] = value;
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v++;
return v;
}
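// Note: with shifts only up to >> 8 this handles inputs up to 2^16, which suffices here
// because it is only applied to values bounded by blockDim.x (at most 1024), e.g.
// next_power_of_two(5) == 8 and next_power_of_two(1024) == 1024.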

// Efficiently computes the max of each chunk in "data" of size chunk_len, and
// stores the maximums in out[i / chunk_len]
__device__ void chunk_max(
const size_t numel,
const size_t chunk_len,
const float data,
float* out
) {
__shared__ float buf[1024];
// assumes that threads where i >= numel have already exited
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int block_i = threadIdx.x;
buf[block_i] = data;

unsigned int chunk_i = i % chunk_len;
unsigned int chunk_start = max((int)(block_i - chunk_i), 0);
unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x);

chunk_i = block_i - chunk_start;

size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x);
size_t incr = next_power_of_two(max_chunk_len) >> 1;

__syncthreads();

// Uses sequential addressing as discussed in
// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
for (; incr > 0; incr >>= 1) {
unsigned int block_i_2 = block_i + incr;

if (block_i_2 < chunk_end && chunk_i < incr) {
// This is sound because __syncthreads and the conditions above
// ensure that no data races occur
buf[block_i] = fmaxf(buf[block_i], buf[block_i_2]);
}

__syncthreads();
}

if (block_i == chunk_start) {
atomicMaxf(out + i / chunk_len, buf[block_i]);
}
}
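// Illustration (assumed launch parameters): with blockDim.x = 8, chunk_len = 4 and
// numel = 16, global threads 0-3 reduce into out[0], threads 4-7 into out[1], and so
// on. Each chunk runs a shared-memory tree reduction (incr = 2, then 1), after which
// the first thread of the chunk publishes its partial result with atomicMaxf; the
// atomic also merges chunks that happen to straddle a block boundary.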

// strides and dims specify how to index inp to put all summed elements next to
// each other, and chunk_len is len(inp) / len(out)
extern "C" __global__ void max_to_forward(
const size_t numel,
const size_t num_dims,
const size_t *dims,
const size_t chunk_len,
const float *inp,
const size_t *inp_strides,
float *out,
const size_t *out_strides
const size_t *dims,
const size_t *strides,
float *out
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
return;
}

unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides);
unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);

atomicMaxf(out + out_strided_i, inp[inp_strided_i]);
unsigned int inp_i = get_strided_index(i, num_dims, dims, strides);
chunk_max(numel, chunk_len, inp[inp_i], out);
}

// Accepts pre-broadcasted strides for both input & output.
// So both inp & out are expected to be broadcasted to the same size.
extern "C" __global__ void max_to_backward(
const size_t numel,
const size_t num_dims,
const float elems_per_thread,
const size_t *dims,
const float *inp,
float *grad_inp,
@@ -67,15 +135,15 @@ extern "C" __global__ void max_to_backward(
const float *grad_out,
const size_t *out_strides
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int inp_i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= numel) {
if (inp_i >= numel) {
return;
}

unsigned int inp_strided_i = get_strided_index(i, num_dims, dims, inp_strides);
unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);
unsigned int i = get_unstrided_index(inp_i, num_dims, dims, inp_strides);
unsigned int out_i = get_strided_index(i, num_dims, dims, out_strides);

auto tmp = inp[inp_strided_i] == out[out_strided_i] ? grad_out[out_strided_i] : 0.0;
atomicAdd(grad_inp + inp_strided_i, tmp);
auto tmp = inp[inp_i] == out[out_i] ? grad_out[out_i] : 0.0;
grad_inp[inp_i] += tmp * elems_per_thread;
}
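As a mental model for the forward kernel, here is a sequential sketch of its effect (Rust-flavored; get_strided_index is a hypothetical host-side stand-in for the device function, and out is assumed to be pre-filled with the reduction identity, e.g. -INFINITY):

    for i in 0..physical_numel {
        let inp_i = get_strided_index(i, num_dims, &dims, &strides);
        out[i / chunk_len] = f32::max(out[i / chunk_len], inp[inp_i]);
    }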
2 changes: 1 addition & 1 deletion src/tensor_ops/max_to/mod.rs
@@ -87,7 +87,7 @@ mod tests {
use crate::tests::{assert_close, TestDevice};

#[test]
fn test_valids_max_axis() {
fn test_max_valid_axes() {
let dev: TestDevice = Default::default();
let _ = dev.zeros::<Rank1<5>>().max::<Rank0, _>();
let _ = dev.zeros::<Rank2<5, 3>>().max::<Rank1<3>, _>();
38 changes: 22 additions & 16 deletions src/tensor_ops/min_to/cuda_kernel.rs
@@ -1,3 +1,4 @@
use crate::tensor_ops::internal_reshapes::permute_for_reductions;
use crate::{
shapes::{Axes, BroadcastStridesTo, ReduceShapeTo, Shape},
tensor::cuda::{Cuda, CudaArray},
@@ -38,21 +39,22 @@ impl super::MinReduceKernel<f32> for Cuda {

let fwd_fn = self.dev.get_func(MODULE_NAME, FWD_FN_NAME).unwrap();

let dims: CudaSlice<usize> = self.dev.take_async(inp.shape.concrete().into())?;
let inp_strides: CudaSlice<usize> = self.dev.take_async(inp.strides.into())?;
let out_strides = BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;
let (dims, strides) = permute_for_reductions::<_, Ax>(inp.shape.concrete(), inp.strides);
let dims: CudaSlice<usize> = self.dev.take_async(dims)?;
let strides: CudaSlice<usize> = self.dev.take_async(strides)?;

let physical_numel = inp.data.len();
let chunk_len = physical_numel / dst.num_elements();

let inp_numel = inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
&dims, // const size_t *dims,
physical_numel, // const size_t numel,
dims.len(), // const size_t num_dims,
chunk_len, // const size_t chunk_len,
inp.data.as_ref(), // const float *inp,
&inp_strides, // const size_t *inp_strides,
&mut storage, // float *out,
&out_strides, // const size_t *out_strides
&dims, // const size_t *dims,
&strides, // const size_t *strides,
&mut storage, // float *out
);
unsafe { fwd_fn.launch_async(cfg, params) }?;
Ok(CudaArray {
@@ -80,11 +82,15 @@ impl super::MinReduceKernel<f32> for Cuda {
BroadcastStridesTo::<Src, Ax>::broadcast_strides(&grad_out.shape, grad_out.strides);
let out_strides: CudaSlice<usize> = self.dev.take_async(out_strides.into())?;

let inp_numel = grad_inp.shape.num_elements();
let cfg = LaunchConfig::for_num_elems(inp_numel as u32);
let physical_numel = grad_inp.data.len();
let virtual_numel = grad_inp.shape.num_elements();
let elems_per_thread = (virtual_numel / physical_numel) as f32;

let cfg = LaunchConfig::for_num_elems(physical_numel as u32);
let params = (
inp_numel, // size_t numel,
Src::NUM_DIMS, // size_t num_dims,
physical_numel, // const size_t numel,
Src::NUM_DIMS, // const size_t num_dims,
elems_per_thread as f32, // const float elems_per_thread,
&dims, // const size_t *dims,
inp.data.as_ref(), // const float *inp,
Arc::make_mut(&mut grad_inp.data), // float *grad_inp,