Add cuda kernels for conv2d (#369)
* #333 conv2d forward implemented

* Adding todos

* Fixing bug with conv2d cpu kernel

* Working commit of cuda kernel; backward pass not yet passing for weight grad

* Some tests passing

* Cleanup, tests still failing

* Fixing bug with matmul with size-1 dimensions
coreylowman authored Jan 20, 2023
1 parent a45e2a1 commit b9617b1
Showing 5 changed files with 366 additions and 93 deletions.
src/tensor_ops/conv2d/conv2d.cu (175 additions, 0 deletions)
@@ -0,0 +1,175 @@
struct Conv2DOp {
    size_t stride;
    size_t padding;
    size_t kernel;
    size_t batch;
    size_t chan_in;
    size_t chan_out;
    size_t h_in;
    size_t h_out;
    size_t w_in;
    size_t w_out;
};
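// Note: h_out and w_out are filled in on the host with the usual convolution
// arithmetic, h_out = (h_in + 2 * padding - kernel) / stride + 1 (the same
// `(H + 2 * P - K) / S + 1` expression the Rust shape types use).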

extern "C" __global__ void unfold_input_into_patches(
    const Conv2DOp op,
    const float *image, // 4d (Batch, Channels, Height, Width)
    float *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    const auto patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
    if (i >= patches_numel) {
        return;
    }

    // patches shape is (B, C, K, K, h_out, w_out)
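    // For this row-major layout the flat index is
    //   i = ((((b * C + c) * K + k1) * K + k2) * h_out + oh) * w_out + ow,
    // so the repeated `% dim; /= dim` below peels the axes off back to front
    // (w_out first, batch last).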
    unsigned int idx = i;
    const size_t ow = idx % op.w_out;
    idx /= op.w_out;
    const size_t oh = idx % op.h_out;
    idx /= op.h_out;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    const size_t b = idx % op.batch;
    idx /= op.batch;

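    // The bounds checks below implement zero padding implicitly: the host
    // zero-initializes the patches buffer (alloc_zeros_async), so a thread
    // whose source pixel falls outside the image just returns, leaving a 0.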
    const size_t y_plus_p = oh * op.stride + k1;
    if (y_plus_p < op.padding) {
        return;
    }
    const size_t y = y_plus_p - op.padding;
    if (y >= op.h_in) {
        return;
    }

    const size_t x_plus_p = ow * op.stride + k2;
    if (x_plus_p < op.padding) {
        return;
    }
    const size_t x = x_plus_p - op.padding;
    if (x >= op.w_in) {
        return;
    }

    const size_t i_image = b * (op.chan_in * op.h_in * op.w_in) + c * (op.h_in * op.w_in) + y * (op.w_in) + x;
    patches[i] = image[i_image];
}

extern "C" __global__ void unfold_output_into_patches(
    const Conv2DOp op,
    const float *image_out, // 4d (Batch, ChanOut, HeightOut, WidthOut)
    float *patches // 6d (Batch, ChanOut, KernelSize, KernelSize, HeightIn, WidthIn)
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    const auto patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
    if (i >= patches_numel) {
        return;
    }

    unsigned int idx = i;
    const size_t x = idx % op.w_in;
    idx /= op.w_in;
    const size_t y = idx % op.h_in;
    idx /= op.h_in;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t o = idx % op.chan_out;
    idx /= op.chan_out;
    const size_t b = idx % op.batch;
    idx /= op.batch;

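    // Invert the forward mapping y = oh * stride + k1 - padding: solve for oh
    // (and likewise ow) and bail out if y + padding - k1 is negative, not a
    // multiple of the stride, or out of range, i.e. if no output position
    // reads this input pixel through kernel offset (k1, k2).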
    size_t oh = y + op.padding;
    if (oh < k1) {
        return;
    }
    oh -= k1;
    if (oh % op.stride != 0) {
        return;
    }
    oh /= op.stride;
    if (oh >= op.h_out) {
        return;
    }

    size_t ow = x + op.padding;
    if (ow < k2) {
        return;
    }
    ow -= k2;
    if (ow % op.stride != 0) {
        return;
    }
    ow /= op.stride;
    if (ow >= op.w_out) {
        return;
    }

    size_t image_i = b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out) + oh * (op.w_out) + ow;
    patches[i] = image_out[image_i];
}

extern "C" __global__ void transpose_and_broadcast_filters(
    const Conv2DOp op,
    const float *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
    float *filters_tr // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    auto numel = op.chan_in * op.chan_out * op.kernel * op.kernel;
    if (i >= numel) {
        return;
    }

    unsigned int idx = i;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    const size_t o = idx % op.chan_out;
    idx /= op.chan_out;

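    // i walks the (ChanOut, ChanIn, K, K) filters layout; i_tr is the same
    // element's offset in the transposed (ChanIn, ChanOut, K, K) layout,
    // which the loop below replicates across the batch dimension.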
    auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;

    const float f = filters[i];
    for (auto b = 0; b < op.batch; b++) {
        filters_tr[b * numel + i_tr] = f;
    }
}

extern "C" __global__ void sum_transposed_filters(
    const Conv2DOp op,
    const float *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
    float *filters // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    auto numel = op.chan_out * op.chan_in * op.kernel * op.kernel;
    if (i >= numel) {
        return;
    }

    unsigned int idx = i;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    const size_t o = idx % op.chan_out;
    idx /= op.chan_out;

    auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;

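    // Reduce over the broadcast batch dimension, then accumulate into the
    // existing gradient with += (gradients are summed, never overwritten).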
    float tmp = 0.0;
    for (auto b = 0; b < op.batch; b++) {
        tmp += filters_tr[b * numel + i_tr];
    }

    filters[i] += tmp;
}
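
For intuition, the shape arithmetic the kernels above rely on can be checked in a few lines of plain Rust. The field names mirror Conv2DOp; the example sizes are illustrative only, not from the commit.

fn conv2d_out_dim(in_dim: usize, padding: usize, kernel: usize, stride: usize) -> usize {
    (in_dim + 2 * padding - kernel) / stride + 1
}

fn main() {
    // Example: a 2x3x32x32 input batch, 5x5 kernel, stride 1, padding 2.
    let (batch, chan_in, h_in, w_in) = (2, 3, 32, 32);
    let (kernel, stride, padding) = (5, 1, 2);
    let h_out = conv2d_out_dim(h_in, padding, kernel, stride); // 32
    let w_out = conv2d_out_dim(w_in, padding, kernel, stride); // 32
    // unfold_input_into_patches launches one thread per element of the
    // (B, C, K, K, h_out, w_out) patches buffer:
    let patches_numel = batch * chan_in * kernel * kernel * h_out * w_out;
    assert_eq!(patches_numel, 153_600);
    println!("h_out={h_out}, w_out={w_out}, patches_numel={patches_numel}");
}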
src/tensor_ops/conv2d/cpu_kernel.rs (1 addition, 1 deletion)
@@ -191,7 +191,7 @@ impl Conv2DKernel<f32> for Cpu {
                 + c * rhs.strides[1]
                 + k1 * rhs.strides[2]
                 + k2 * rhs.strides[3];
-            buf[idx] = *f;
+            buf[idx] += *f;
         }
     }

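The one-line change above swaps an overwriting store for an accumulating one. When several source elements scatter to the same destination index in buf (as strided writes like this can), `=` keeps only the last contribution while `+=` sums them all, which is what gradient-style buffers require. A minimal illustration in plain Rust, not from the commit:

fn main() {
    // Two source elements scatter to the same destination slot:
    let contributions = [(0usize, 1.0f32), (0, 2.0)];
    let (mut overwrite, mut accumulate) = ([0.0f32; 1], [0.0f32; 1]);
    for &(idx, f) in &contributions {
        overwrite[idx] = f; // old behavior: keeps only the last value
        accumulate[idx] += f; // fixed behavior: sums every contribution
    }
    assert_eq!(overwrite[0], 2.0);
    assert_eq!(accumulate[0], 3.0);
}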
src/tensor_ops/conv2d/cuda_kernel.rs (148 additions, 53 deletions)
@@ -1,63 +1,158 @@
-use crate::{shapes::*, tensor::Cuda};
 use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig};
 
-impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize>
-    super::Conv2DKernel<f32, C, O, K, S, P> for Cuda
-{
-    fn forward<const H: usize, const W: usize>(
-        &self,
-        lhs: &Self::Storage<Rank3<C, H, W>, f32>,
-        rhs: &Self::Storage<Rank4<O, C, K, K>, f32>,
-    ) -> Result<
-        Self::Storage<Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>, f32>,
-        Self::Err,
-    > {
-        todo!()
-    }
-
-    fn backward<const H: usize, const W: usize>(
-        &self,
-        lhs: &Self::Storage<Rank3<C, H, W>, f32>,
-        grad_lhs: &mut Self::Storage<Rank3<C, H, W>, f32>,
-        rhs: &Self::Storage<Rank4<O, C, K, K>, f32>,
-        grad_rhs: &mut Self::Storage<Rank4<O, C, K, K>, f32>,
-        grad_out: &Self::Storage<
-            Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>,
-            f32,
-        >,
-    ) -> Result<(), Self::Err> {
-        todo!()
-    }
-}
-
-impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize>
-    super::Conv2DBatchedKernel<f32, C, O, K, S, P> for Cuda
-{
-    #[rustfmt::skip]
-    fn forward<B: Dim, const H: usize, const W: usize>(
-        &self,
-        lhs: &Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>,
-        rhs: &Self::Storage<Rank4<O, C, K, K>, f32>,
-    ) -> Result<
-        Self::Storage<
-            (B, Const<O>, Const<{ (H + 2 * P - K) / S + 1 }>, Const<{ (W + 2 * P - K) / S + 1 }>),
-            f32,
-        >,
-        Self::Err,
-    > {
-        todo!()
-    }
-    #[rustfmt::skip]
-    fn backward<B: Dim, const H: usize, const W: usize>(
-        &self,
-        lhs: &Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>,
-        grad_lhs: &mut Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>,
-        rhs: &Self::Storage<Rank4<O, C, K, K>, f32>,
-        grad_rhs: &mut Self::Storage<Rank4<O, C, K, K>, f32>,
-        grad_out: &Self::Storage<
-            (B, Const<O>, Const<{ (H + 2 * P - K) / S + 1 }>, Const<{ (W + 2 * P - K) / S + 1 }>),
-            f32,
-        >,
-    ) -> Result<(), Self::Err> {
-        todo!()
-    }
-}
+use crate::tensor_ops::matmul::cuda_kernel::sgemm_batch;
+use crate::{shapes::*, tensor::cuda::Cuda};
+
+use std::sync::Arc;
+
+const MODULE_NAME: &str = "conv2d";
+const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/conv2d.ptx"));
+const UNFOLD_INPUT_FN: &str = "unfold_input_into_patches";
+const UNFOLD_OUTPUT_FN: &str = "unfold_output_into_patches";
+const BR_TR_FILTERS_FN: &str = "transpose_and_broadcast_filters";
+const COLLECT_GRADS_FN: &str = "sum_transposed_filters";
+const ALL_FN_NAMES: [&str; 4] = [
+    UNFOLD_INPUT_FN,
+    UNFOLD_OUTPUT_FN,
+    BR_TR_FILTERS_FN,
+    COLLECT_GRADS_FN,
+];
+
+unsafe impl AsKernelParam for super::Conv2DOp {}
+
+impl super::Conv2DKernel<f32> for Cuda {
+    fn forward<L: Shape, R: Shape, O: Shape>(
+        &self,
+        op: super::Conv2DOp,
+        lhs: &Self::Storage<L, f32>,
+        rhs: &Self::Storage<R, f32>,
+        out: &mut Self::Storage<O, f32>,
+    ) -> Result<(), Self::Err> {
+        assert_eq!(
+            lhs.shape().strides(),
+            lhs.strides,
+            "Only works with contiguous image strides"
+        );
+
+        if !self.dev.has_func(MODULE_NAME, ALL_FN_NAMES[0]) {
+            self.dev
+                .load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
+        }
+
+        let patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
+        let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?;
+
+        let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_INPUT_FN).unwrap();
+        let cfg = LaunchConfig::for_num_elems(patches.len() as u32);
+        let params = (op, lhs.data.as_ref(), &mut patches);
+        unsafe { unfold_fn.launch_async(cfg, params) }?;
+
+        // (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW)
+        let m = op.chan_out;
+        let k = op.chan_in * op.kernel * op.kernel;
+        let n = op.h_out * op.w_out;
+        unsafe {
+            sgemm_batch(
+                self.blas.as_ref(),
+                (op.batch, m, k, n),
+                rhs.data.as_ref(),
+                [0, k, 1],
+                &patches,
+                [k * n, n, 1],
+                0.0,
+                Arc::make_mut(&mut out.data),
+                [m * n, n, 1],
+            )
+            .unwrap();
+        }
+
+        Ok(())
+    }
+
+    fn backward<L: Shape, R: Shape, O: Shape>(
+        &self,
+        op: super::Conv2DOp,
+        lhs: &Self::Storage<L, f32>,
+        grad_lhs: &mut Self::Storage<L, f32>,
+        rhs: &Self::Storage<R, f32>,
+        grad_rhs: &mut Self::Storage<R, f32>,
+        grad_out: &Self::Storage<O, f32>,
+    ) -> Result<(), Self::Err> {
+        let patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
+        let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?;
+
+        {
+            // unfold grad_out into patches
+            let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_OUTPUT_FN).unwrap();
+            let cfg = LaunchConfig::for_num_elems(patches_numel as u32);
+            let params = (op, grad_out.data.as_ref(), &mut patches);
+            unsafe { unfold_fn.launch_async(cfg, params) }?;
+        }
+
+        let filters_numel = op.batch * op.chan_in * op.chan_out * op.kernel * op.kernel;
+        let mut f_b1023 = self.dev.alloc_zeros_async::<f32>(filters_numel)?;
+        let mut grad_f_b1023 = self.dev.alloc_zeros_async::<f32>(filters_numel)?;
+
+        {
+            // prepare filters for backward operations by
+            // swapping dims 0 and 1 and adding a batch dimension
+            let tr_fn = self.dev.get_func(MODULE_NAME, BR_TR_FILTERS_FN).unwrap();
+            let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32);
+            let params = (op, rhs.data.as_ref(), &mut f_b1023);
+            unsafe { tr_fn.launch_async(cfg, params) }?;
+        }
+
+        {
+            // img_g += filters * patches
+            // (B, C, H * W) += (B, C, O * K * K) * (B, O * K * K, H * W)
+            let m = op.chan_in;
+            let k = op.chan_out * op.kernel * op.kernel;
+            let n = op.h_in * op.w_in;
+            unsafe {
+                sgemm_batch(
+                    self.blas.as_ref(),
+                    (op.batch, m, k, n),
+                    &f_b1023,
+                    [m * k, k, 1],
+                    &patches,
+                    [k * n, n, 1],
+                    1.0,
+                    Arc::make_mut(&mut grad_lhs.data),
+                    [m * n, n, 1],
+                )
+                .unwrap();
+            }
+        }
+
+        {
+            // weight_g += img * patches^T
+            // (B, C, O * K * K) += (B, C, H * W) * (B, H * W, O * K * K)
+            let m = op.chan_in;
+            let k = op.h_in * op.w_in;
+            let n = op.chan_out * op.kernel * op.kernel;
+            unsafe {
+                sgemm_batch(
+                    self.blas.as_ref(),
+                    (op.batch, m, k, n),
+                    lhs.data.as_ref(),
+                    [m * k, k, 1],
+                    &patches,
+                    [k * n, 1, k],
+                    1.0,
+                    &mut grad_f_b1023,
+                    [m * n, n, 1],
+                )
+                .unwrap();
+            }
+
+            // sum all the gradients collected in our broadcasted grad_f
+            // into grad_rhs
+            let sum_fn = self.dev.get_func(MODULE_NAME, COLLECT_GRADS_FN).unwrap();
+            let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32);
+            let params = (op, &grad_f_b1023, Arc::make_mut(&mut grad_rhs.data));
+            unsafe { sum_fn.launch_async(cfg, params) }?;
+        }
+
+        Ok(())
+    }
+}
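
For readers new to the approach: both passes are im2col-style. Forward unfolds the input into a (C * K * K, OH * OW) patch matrix per batch element and multiplies it by the (O, C * K * K) filter matrix in one batched GEMM; backward runs the analogous GEMMs against patches of grad_out. Below is a minimal single-image CPU sketch of the forward strategy in plain Rust. The names mirror the kernels above, the naive triple loop stands in for sgemm_batch, and none of this code is part of the commit.

struct Conv2DOp {
    stride: usize,
    padding: usize,
    kernel: usize,
    chan_in: usize,
    chan_out: usize,
    h_in: usize,
    h_out: usize,
    w_in: usize,
    w_out: usize,
}

fn unfold_input_into_patches(op: &Conv2DOp, image: &[f32]) -> Vec<f32> {
    // (C, K, K, h_out, w_out), zero-initialized so out-of-image taps stay 0.
    let mut patches = vec![0.0; op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out];
    let mut i = 0;
    for c in 0..op.chan_in {
        for k1 in 0..op.kernel {
            for k2 in 0..op.kernel {
                for oh in 0..op.h_out {
                    for ow in 0..op.w_out {
                        let y = (oh * op.stride + k1).checked_sub(op.padding);
                        let x = (ow * op.stride + k2).checked_sub(op.padding);
                        if let (Some(y), Some(x)) = (y, x) {
                            if y < op.h_in && x < op.w_in {
                                patches[i] = image[(c * op.h_in + y) * op.w_in + x];
                            }
                        }
                        i += 1;
                    }
                }
            }
        }
    }
    patches
}

fn conv2d_forward(op: &Conv2DOp, image: &[f32], filters: &[f32]) -> Vec<f32> {
    // (O, C*K*K) x (C*K*K, OH*OW) = (O, OH*OW): the same per-element shapes
    // the sgemm_batch call above uses.
    let (m, k, n) = (
        op.chan_out,
        op.chan_in * op.kernel * op.kernel,
        op.h_out * op.w_out,
    );
    let patches = unfold_input_into_patches(op, image);
    let mut out = vec![0.0; m * n];
    for row in 0..m {
        for col in 0..n {
            let mut acc = 0.0;
            for j in 0..k {
                acc += filters[row * k + j] * patches[j * n + col];
            }
            out[row * n + col] = acc;
        }
    }
    out
}

fn main() {
    // 1 channel, 2x2 image, 1x1 kernel, stride 1, no padding: the convolution
    // reduces to scaling every pixel by the single filter weight.
    let op = Conv2DOp {
        stride: 1, padding: 0, kernel: 1,
        chan_in: 1, chan_out: 1,
        h_in: 2, h_out: 2, w_in: 2, w_out: 2,
    };
    let out = conv2d_forward(&op, &[1.0, 2.0, 3.0, 4.0], &[2.0]);
    assert_eq!(out, vec![2.0, 4.0, 6.0, 8.0]);
}

Two details of the real CUDA path worth noting. In the forward sgemm_batch call the filters get strides [0, k, 1]: the zero batch stride reuses the one (O, C * K * K) filter matrix for every batch element instead of copying it. In the weight-gradient call the patches are passed with strides [k * n, 1, k], i.e. row and column strides swapped, which reads patches^T without ever materializing the transpose.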