
Add cuda kernels for conv2d #369

Merged: 9 commits, Jan 20, 2023
109 changes: 109 additions & 0 deletions src/tensor_ops/conv2d/conv2d.cu
@@ -0,0 +1,109 @@
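// Must stay in sync with the #[repr(C)] ConvParams struct declared on the Rust side (cuda_kernel.rs).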
struct ConvParams {
size_t channels_in;
size_t height_in;
size_t width_in;
size_t stride;
size_t padding;
size_t kernel;
size_t channels_out;
size_t height_out;
size_t width_out;
};

extern "C" __global__ void unfold_input_into_patches(
const ConvParams op,
const float *image,
const size_t *image_strides,
float *patches,
const size_t numel
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
return;
}

// patches shape is (C, K, K, height_out, width_out)
unsigned idx = i;
const size_t ow = idx % op.width_out;
idx /= op.width_out;
const size_t oh = idx % op.height_out;
idx /= op.height_out;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.channels_in;

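// Map the (output pixel, kernel offset) pair back to an input coordinate;
// bail out when it falls inside the zero padding or beyond the image bounds.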
const size_t y_plus_p = oh * op.stride + k1;
if (y_plus_p < op.padding) {
return;
}
const size_t y = y_plus_p - op.padding;

const size_t x_plus_p = ow * op.stride + k2;
if (x_plus_p < op.padding) {
return;
}
const size_t x = x_plus_p - op.padding;

if (y >= op.height_in || x >= op.width_in) {
return;
}

patches[i] = image[c * image_strides[0] + y * image_strides[1] + x * image_strides[2]];
}

extern "C" __global__ void unfold_output_into_patches(
const ConvParams op,
const float *image,
const size_t *image_strides,
float *patches,
const size_t numel
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
return;
}

// patches shape is (channels_out, K, K, height_in, width_in)
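// Inverse of the forward unfold: each (input position, kernel offset) pair maps back
// to the single output pixel whose receptive field covered it, if any.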
unsigned idx = i;
const size_t x = idx % op.width_in;
idx /= op.width_in;
const size_t y = idx % op.height_in;
idx /= op.height_in;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t o = idx % op.channels_out;

if (y + op.padding < k1) {
return;
}
const size_t oh_mul_s = y + op.padding - k1;
if (oh_mul_s % op.stride != 0) {
return;
}
const size_t oh = oh_mul_s / op.stride;

if (x + op.padding < k2) {
return;
}
const size_t ow_mul_s = x + op.padding - k2;
if (ow_mul_s % op.stride != 0) {
return;
}
const size_t ow = ow_mul_s / op.stride;

patches[i] = image[o * image_strides[0] + oh * image_strides[1] + ow * image_strides[2]];
}

extern "C" __global__ void transpose_filters(

) {

}

extern "C" __global__ void sum_transposed_filters() {

}
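For readers new to the im2col approach these kernels implement, here is a minimal CPU-side sketch (Rust) of what unfold_input_into_patches computes. The function name, signature, and loop structure are illustrative only and are not part of this diff.

/// Illustrative CPU version of the unfold above: for every output pixel and
/// kernel offset, copy the corresponding input pixel (or leave zero for padding)
/// into a patches buffer of shape (C, K, K, height_out, width_out).
fn unfold_input_into_patches_cpu(
    image: &[f32],       // (C, H, W), row-major
    patches: &mut [f32], // (C, K, K, height_out, width_out), zero-initialized
    (channels, h_in, w_in): (usize, usize, usize),
    (kernel, stride, padding): (usize, usize, usize),
    (h_out, w_out): (usize, usize),
) {
    for c in 0..channels {
        for k1 in 0..kernel {
            for k2 in 0..kernel {
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        // Same coordinate math as the CUDA kernel: y = oh*stride + k1 - padding.
                        let (y, x) = (oh * stride + k1, ow * stride + k2);
                        if y < padding || x < padding {
                            continue;
                        }
                        let (y, x) = (y - padding, x - padding);
                        if y >= h_in || x >= w_in {
                            continue;
                        }
                        let src = (c * h_in + y) * w_in + x;
                        let dst = (((c * kernel + k1) * kernel + k2) * h_out + oh) * w_out + ow;
                        patches[dst] = image[src];
                    }
                }
            }
        }
    }
}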
146 changes: 143 additions & 3 deletions src/tensor_ops/conv2d/cuda_kernel.rs
@@ -1,4 +1,40 @@
use crate::{shapes::*, tensor::Cuda};
use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig};

use crate::tensor_ops::matmul::cuda_kernel::sgemm;
use crate::{
shapes::*,
tensor::cuda::{Cuda, CudaArray},
};

use std::sync::Arc;

const MODULE_NAME: &str = "conv2d";
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/conv2d.ptx"));
const UNFOLD_INPUT_FN: &str = "unfold_input_into_patches";
const UNFOLD_OUTPUT_FN: &str = "unfold_output_into_patches";
const TRANSPOSE_FILTERS_FN: &str = "transpose_filters";
const SUM_TRANSPOSED_FILTERS_FN: &str = "sum_transposed_filters";
const ALL_FN_NAMES: [&str; 4] = [
UNFOLD_INPUT_FN,
UNFOLD_OUTPUT_FN,
TRANSPOSE_FILTERS_FN,
SUM_TRANSPOSED_FILTERS_FN,
];

#[repr(C)]
struct ConvParams {
channels_in: usize,
height_in: usize,
width_in: usize,
stride: usize,
padding: usize,
kernel: usize,
channels_out: usize,
height_out: usize,
width_out: usize,
}

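// SAFETY (assumed): ConvParams is #[repr(C)] and contains only usize fields, so its
// layout matches the struct the CUDA kernels receive by value.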
unsafe impl AsKernelParam for ConvParams {}

impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize>
super::Conv2DKernel<f32, C, O, K, S, P> for Cuda
@@ -11,7 +47,65 @@ impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize>
Self::Storage<Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>, f32>,
Self::Err,
> {
todo!()
if !self.dev.has_func(MODULE_NAME, ALL_FN_NAMES[0]) {
self.dev
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
}

let height_out = (H + 2 * P - K) / S + 1;
let width_out = (W + 2 * P - K) / S + 1;
let patches_numel = C * K * K * height_out * width_out;
let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?;

let lhs_strides = self.dev.take_async(lhs.strides.into())?;

let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_INPUT_FN).unwrap();
let cfg = LaunchConfig::for_num_elems(patches.len() as u32);
let params = (
ConvParams {
channels_in: C,
height_in: H,
width_in: W,
stride: S,
padding: P,
kernel: K,
channels_out: O,
height_out,
width_out,
},
lhs.data.as_ref(),
&lhs_strides,
&mut patches,
patches_numel,
);
unsafe { unfold_fn.launch_async(cfg, params) }?;

let shape = (Const, Const, Const);
let strides = shape.strides();
let mut storage = self.dev.alloc_zeros_async::<f32>(shape.num_elements())?;

let m = O;
let k = C * K * K;
let n = width_out * height_out;
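// GEMM shapes: the filters act as an [O, C*K*K] matrix and the unfolded patches
// as a [C*K*K, height_out*width_out] matrix, so one sgemm yields the
// [O, height_out*width_out] result, i.e. the (O, H_out, W_out) output image.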
unsafe {
sgemm(
self.blas.as_ref(),
(m, k, n),
rhs.data.as_ref(),
[k, 1],
&patches,
[n, 1],
0.0,
&mut storage,
[n, 1],
)?;
}

Ok(CudaArray {
data: Arc::new(storage),
shape,
strides,
})
}

fn backward<const H: usize, const W: usize>(
@@ -25,7 +119,53 @@ impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize>
f32,
>,
) -> Result<(), Self::Err> {
todo!()
let height_out = (H + 2 * P - K) / S + 1;
let width_out = (W + 2 * P - K) / S + 1;
let patches_numel = O * K * K * H * W;
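// Unlike the forward pass, grad_out is unfolded over the full input extent (H, W),
// so the patches buffer holds O * K * K * H * W elements.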
let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?;
let grad_out_strides = self.dev.take_async(grad_out.strides.into())?;

{
let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_OUTPUT_FN).unwrap();
let cfg = LaunchConfig::for_num_elems(patches.len() as u32);
let params = (
ConvParams {
channels_in: C,
height_in: H,
width_in: W,
stride: S,
padding: P,
kernel: K,
channels_out: O,
height_out,
width_out,
},
grad_out.data.as_ref(),
&grad_out_strides,
&mut patches,
patches_numel,
);
unsafe { unfold_fn.launch_async(cfg, params) }?;
}

{
todo!("call transpose_filters");
}

{
// img_g += filters^T * unfold(grad_out)
todo!("call sgemm");
}

{
// weight_g^T += img * patches^T
todo!("allocate zeros for grad_rhs and call sgemm");
}

{
todo!("call sum_transposed_filters to add transposed filters to grad_rhs")
}
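// Rough shape outline of the steps the todo!()s above describe (names illustrative):
//   1. transpose_filters: view rhs (O, C, K, K) as (C, O, K, K).
//   2. grad_lhs [C, H*W] += filters_tr [C, O*K*K] x patches [O*K*K, H*W]
//   3. grad_rhs_tr [C, O*K*K] += lhs [C, H*W] x patches^T [H*W, O*K*K]
//   4. sum_transposed_filters: fold grad_rhs_tr back into grad_rhs with shape (O, C, K, K).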
Ok(())
}
}

12 changes: 3 additions & 9 deletions src/tensor_ops/conv2d/mod.rs
@@ -5,7 +5,7 @@ mod cuda_kernel;

use crate::{
gradients::Tape,
shapes::{Const, Dim, Dtype, Rank3, Rank4},
shapes::*,
tensor::{DeviceStorage, HasErr, PutTape, SplitTape, Tensor},
};

@@ -110,7 +110,7 @@ impl<
T: Tape<D>,
> TryConv2DTo<Tensor<Rank4<O, C, K, K>, f32, D>, S, P> for Tensor<Rank3<C, H, W>, f32, D, T>
where
Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>: Sized,
Rank2<{ (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>: Sized,
{
type Output =
Tensor<Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>, f32, D, T>;
@@ -151,12 +151,7 @@ impl<
> TryConv2DTo<Tensor<Rank4<O, C, K, K>, f32, D>, S, P>
for Tensor<(B, Const<C>, Const<H>, Const<W>), f32, D, T>
where
(
B,
Const<O>,
Const<{ (H + 2 * P - K) / S + 1 }>,
Const<{ (W + 2 * P - K) / S + 1 }>,
):,
Rank2<{ (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>:,
{
type Output = Tensor<
(
@@ -201,7 +196,6 @@ where
mod tests {
use super::*;
use crate::{
shapes::*,
tensor::*,
tensor_ops::*,
tests::{assert_close, AssertClose, TestDevice},
2 changes: 1 addition & 1 deletion src/tensor_ops/matmul/cuda_kernel.rs
@@ -91,7 +91,7 @@ fn sgemm_config<M: Dim, K: Dim, N: Dim>(
///
/// lhs is a and rhs is b, but we have to transpose them if they are not already
#[allow(clippy::too_many_arguments)]
unsafe fn sgemm<
pub(crate) unsafe fn sgemm<
M: Dim,
K: Dim,
N: Dim,
2 changes: 1 addition & 1 deletion src/tensor_ops/matmul/mod.rs
@@ -3,7 +3,7 @@
pub(super) mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;
pub(super) mod cuda_kernel;

use crate::{
gradients::{Merge, Tape},