-
-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* #333 conv2d forward implemented * Adding todos * Fixing bug with conv2d cpu kernel * Working commit of cuda kernel. backwards not passing for weight grad * Some tests passing * Cleanup, tests still failing * Fixing bug with matmul with 1 dimensions
- Loading branch information
1 parent
a45e2a1
commit b9617b1
Showing
5 changed files
with
366 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
// Geometry of a single 2D convolution with a square kernel.
//
// Every kernel in this file receives this struct by value as its first
// argument. NOTE(review): the host side presumably declares a mirror of this
// struct; keep the field order and the all-size_t layout in sync with it.
struct Conv2DOp {
    size_t stride;   // step, in input pixels, between neighbouring output positions
    size_t padding;  // implicit border width added on each side of the input
    size_t kernel;   // side length K of the square K x K filter window
    size_t batch;    // number of images in the batch
    size_t chan_in;  // channels of the input image
    size_t chan_out; // channels of the output image (number of filters)
    size_t h_in;     // input height
    size_t h_out;    // output height
    size_t w_in;     // input width
    size_t w_out;    // output width
};
|
||
// Copies input pixels into a patch buffer laid out as
// (Batch, ChanIn, Kernel, Kernel, HeightOut, WidthOut).
//
// Launch layout: 1D grid of 1D blocks, one thread per element of `patches`;
// surplus threads exit at the bounds check.
//
// Elements whose source pixel falls into the padding region are NOT written,
// so the caller must pre-fill `patches` (e.g. with zeros) before launching.
//
// The flat index is computed in size_t: blockIdx.x * blockDim.x is a 32-bit
// product that can wrap when the launch covers close to 2^32 threads.
extern "C" __global__ void unfold_input_into_patches(
    const Conv2DOp op,
    const float *image, // 4d (Batch, Channels, Height, Width)
    float *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
    if (i >= patches_numel) {
        return;
    }

    // Decompose the flat index, innermost dimension first.
    // patches shape is (B, C, K, K, h_out, w_out)
    size_t idx = i;
    const size_t ow = idx % op.w_out;
    idx /= op.w_out;
    const size_t oh = idx % op.h_out;
    idx /= op.h_out;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    // i < patches_numel guarantees the remainder is already < op.batch.
    const size_t b = idx;

    // Row in the input image. The comparison against the padding is done
    // before subtracting because the operands are unsigned.
    const size_t y_plus_p = oh * op.stride + k1;
    if (y_plus_p < op.padding) {
        return;
    }
    const size_t y = y_plus_p - op.padding;
    if (y >= op.h_in) {
        return;
    }

    // Column in the input image, same scheme as the row.
    const size_t x_plus_p = ow * op.stride + k2;
    if (x_plus_p < op.padding) {
        return;
    }
    const size_t x = x_plus_p - op.padding;
    if (x >= op.w_in) {
        return;
    }

    const size_t i_image = b * (op.chan_in * op.h_in * op.w_in) + c * (op.h_in * op.w_in) + y * (op.w_in) + x;
    patches[i] = image[i_image];
}
|
||
// Scatters output-image values into a patch buffer laid out as
// (Batch, ChanOut, Kernel, Kernel, HeightIn, WidthIn).
//
// For each (k1, k2, y, x) the kernel looks for the output position (oh, ow)
// satisfying y + padding == oh * stride + k1 and x + padding == ow * stride + k2.
// When no such in-range position exists the element is NOT written, so the
// caller must pre-fill `patches` (e.g. with zeros) before launching.
//
// Launch layout: 1D grid of 1D blocks, one thread per element of `patches`;
// surplus threads exit at the bounds check.
//
// The flat index is computed in size_t: blockIdx.x * blockDim.x is a 32-bit
// product that can wrap when the launch covers close to 2^32 threads.
extern "C" __global__ void unfold_output_into_patches(
    const Conv2DOp op,
    const float *image_out, // 4d (Batch, ChanOut, HeightOut, WidthOut)
    float *patches // 6d (Batch, ChanOut, KernelSize, KernelSize, HeightIn, WidthIn)
) {
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in;
    if (i >= patches_numel) {
        return;
    }

    // Decompose the flat index, innermost dimension first.
    size_t idx = i;
    const size_t x = idx % op.w_in;
    idx /= op.w_in;
    const size_t y = idx % op.h_in;
    idx /= op.h_in;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t o = idx % op.chan_out;
    idx /= op.chan_out;
    // i < patches_numel guarantees the remainder is already < op.batch.
    const size_t b = idx;

    // Solve y + padding == oh * stride + k1 for oh; bail out when the
    // solution is negative, fractional, or outside the output height.
    size_t oh = y + op.padding;
    if (oh < k1) {
        return;
    }
    oh -= k1;
    if (oh % op.stride != 0) {
        return;
    }
    oh /= op.stride;
    if (oh >= op.h_out) {
        return;
    }

    // Same for the column: x + padding == ow * stride + k2.
    size_t ow = x + op.padding;
    if (ow < k2) {
        return;
    }
    ow -= k2;
    if (ow % op.stride != 0) {
        return;
    }
    ow /= op.stride;
    if (ow >= op.w_out) {
        return;
    }

    const size_t image_i = b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out) + oh * (op.w_out) + ow;
    patches[i] = image_out[image_i];
}
|
||
// Writes every filter weight into `filters_tr` with the first two dimensions
// swapped, repeated once per batch entry:
//   filters    (ChanOut, ChanIn, K, K)
//   filters_tr (Batch, ChanIn, ChanOut, K, K)
//
// Launch layout: 1D grid of 1D blocks, one thread per element of `filters`;
// each thread performs op.batch stores. Surplus threads exit at the bounds
// check.
//
// The flat index is computed in size_t: blockIdx.x * blockDim.x is a 32-bit
// product that can wrap when the launch covers close to 2^32 threads.
extern "C" __global__ void transpose_and_broadcast_filters(
    const Conv2DOp op,
    const float *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
    float *filters_tr // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
) {
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t numel = op.chan_in * op.chan_out * op.kernel * op.kernel;
    if (i >= numel) {
        return;
    }

    // Decompose the flat (ChanOut, ChanIn, K, K) index, innermost first.
    size_t idx = i;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    // i < numel guarantees the remainder is already < op.chan_out.
    const size_t o = idx;

    // Offset of the same weight inside one (ChanIn, ChanOut, K, K) slice.
    const size_t i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;

    const float f = filters[i];
    // size_t counter: avoids comparing a signed int against op.batch.
    for (size_t b = 0; b < op.batch; b++) {
        filters_tr[b * numel + i_tr] = f;
    }
}
|
||
// Reduces `filters_tr` over its batch dimension and ACCUMULATES the result
// into `filters`, undoing the ChanIn/ChanOut swap on the way:
//   filters_tr (Batch, ChanIn, ChanOut, K, K)
//   filters    (ChanOut, ChanIn, K, K)   filters[i] += sum_b filters_tr[b, ...]
//
// Launch layout: 1D grid of 1D blocks, one thread per element of `filters`;
// each thread owns exactly one output element, so no atomics are used.
// Surplus threads exit at the bounds check.
//
// The flat index is computed in size_t: blockIdx.x * blockDim.x is a 32-bit
// product that can wrap when the launch covers close to 2^32 threads, and a
// wrapped index here would add the same sum into `filters[i]` twice.
extern "C" __global__ void sum_transposed_filters(
    const Conv2DOp op,
    const float *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
    float *filters // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
) {
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t numel = op.chan_out * op.chan_in * op.kernel * op.kernel;
    if (i >= numel) {
        return;
    }

    // Decompose the flat (ChanOut, ChanIn, K, K) index, innermost first.
    size_t idx = i;
    const size_t k2 = idx % op.kernel;
    idx /= op.kernel;
    const size_t k1 = idx % op.kernel;
    idx /= op.kernel;
    const size_t c = idx % op.chan_in;
    idx /= op.chan_in;
    // i < numel guarantees the remainder is already < op.chan_out.
    const size_t o = idx;

    // Offset of the same weight inside one (ChanIn, ChanOut, K, K) slice.
    const size_t i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;

    // Float literal keeps the accumulator initialisation in single precision.
    float tmp = 0.0f;
    // size_t counter: avoids comparing a signed int against op.batch.
    for (size_t b = 0; b < op.batch; b++) {
        tmp += filters_tr[b * numel + i_tr];
    }

    filters[i] += tmp;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,158 @@ | ||
use crate::{shapes::*, tensor::Cuda}; | ||
use cudarc::driver::{AsKernelParam, LaunchAsync, LaunchConfig}; | ||
|
||
impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize> | ||
super::Conv2DKernel<f32, C, O, K, S, P> for Cuda | ||
{ | ||
fn forward<const H: usize, const W: usize>( | ||
&self, | ||
lhs: &Self::Storage<Rank3<C, H, W>, f32>, | ||
rhs: &Self::Storage<Rank4<O, C, K, K>, f32>, | ||
) -> Result< | ||
Self::Storage<Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>, f32>, | ||
Self::Err, | ||
> { | ||
todo!() | ||
} | ||
use crate::tensor_ops::matmul::cuda_kernel::sgemm_batch; | ||
use crate::{shapes::*, tensor::cuda::Cuda}; | ||
|
||
use std::sync::Arc; | ||
|
||
const MODULE_NAME: &str = "conv2d"; | ||
const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/conv2d.ptx")); | ||
const UNFOLD_INPUT_FN: &str = "unfold_input_into_patches"; | ||
const UNFOLD_OUTPUT_FN: &str = "unfold_output_into_patches"; | ||
const BR_TR_FILTERS_FN: &str = "transpose_and_broadcast_filters"; | ||
const COLLECT_GRADS_FN: &str = "sum_transposed_filters"; | ||
const ALL_FN_NAMES: [&str; 4] = [ | ||
UNFOLD_INPUT_FN, | ||
UNFOLD_OUTPUT_FN, | ||
BR_TR_FILTERS_FN, | ||
COLLECT_GRADS_FN, | ||
]; | ||
|
||
unsafe impl AsKernelParam for super::Conv2DOp {} | ||
|
||
fn backward<const H: usize, const W: usize>( | ||
impl super::Conv2DKernel<f32> for Cuda { | ||
fn forward<L: Shape, R: Shape, O: Shape>( | ||
&self, | ||
lhs: &Self::Storage<Rank3<C, H, W>, f32>, | ||
grad_lhs: &mut Self::Storage<Rank3<C, H, W>, f32>, | ||
rhs: &Self::Storage<Rank4<O, C, K, K>, f32>, | ||
grad_rhs: &mut Self::Storage<Rank4<O, C, K, K>, f32>, | ||
grad_out: &Self::Storage< | ||
Rank3<O, { (H + 2 * P - K) / S + 1 }, { (W + 2 * P - K) / S + 1 }>, | ||
f32, | ||
>, | ||
op: super::Conv2DOp, | ||
lhs: &Self::Storage<L, f32>, | ||
rhs: &Self::Storage<R, f32>, | ||
out: &mut Self::Storage<O, f32>, | ||
) -> Result<(), Self::Err> { | ||
todo!() | ||
} | ||
} | ||
assert_eq!( | ||
lhs.shape().strides(), | ||
lhs.strides, | ||
"Only works with contiguous image strides" | ||
); | ||
|
||
impl<const K: usize, const S: usize, const P: usize, const C: usize, const O: usize> | ||
super::Conv2DBatchedKernel<f32, C, O, K, S, P> for Cuda | ||
{ | ||
#[rustfmt::skip] | ||
fn forward<B: Dim, const H: usize, const W: usize>( | ||
&self, | ||
lhs: &Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>, | ||
rhs: &Self::Storage<Rank4<O, C, K, K>, f32>, | ||
) -> Result< | ||
Self::Storage< | ||
(B, Const<O>, Const<{ (H + 2 * P - K) / S + 1 }>, Const<{ (W + 2 * P - K) / S + 1 }>), | ||
f32, | ||
>, | ||
Self::Err, | ||
> { | ||
todo!() | ||
if !self.dev.has_func(MODULE_NAME, ALL_FN_NAMES[0]) { | ||
self.dev | ||
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?; | ||
} | ||
|
||
let patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out; | ||
let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?; | ||
|
||
let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_INPUT_FN).unwrap(); | ||
let cfg = LaunchConfig::for_num_elems(patches.len() as u32); | ||
let params = (op, lhs.data.as_ref(), &mut patches); | ||
unsafe { unfold_fn.launch_async(cfg, params) }?; | ||
|
||
// (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW) | ||
let m = op.chan_out; | ||
let k = op.chan_in * op.kernel * op.kernel; | ||
let n = op.h_out * op.w_out; | ||
unsafe { | ||
sgemm_batch( | ||
self.blas.as_ref(), | ||
(op.batch, m, k, n), | ||
rhs.data.as_ref(), | ||
[0, k, 1], | ||
&patches, | ||
[k * n, n, 1], | ||
0.0, | ||
Arc::make_mut(&mut out.data), | ||
[m * n, n, 1], | ||
) | ||
.unwrap(); | ||
} | ||
|
||
Ok(()) | ||
} | ||
#[rustfmt::skip] | ||
fn backward<B: Dim, const H: usize, const W: usize>( | ||
|
||
fn backward<L: Shape, R: Shape, O: Shape>( | ||
&self, | ||
lhs: &Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>, | ||
grad_lhs: &mut Self::Storage<(B, Const<C>, Const<H>, Const<W>), f32>, | ||
rhs: &Self::Storage<Rank4<O, C, K, K>, f32>, | ||
grad_rhs: &mut Self::Storage<Rank4<O, C, K, K>, f32>, | ||
grad_out: &Self::Storage< | ||
(B, Const<O>, Const<{ (H + 2 * P - K) / S + 1 }>, Const<{ (W + 2 * P - K) / S + 1 }>), | ||
f32, | ||
>, | ||
op: super::Conv2DOp, | ||
lhs: &Self::Storage<L, f32>, | ||
grad_lhs: &mut Self::Storage<L, f32>, | ||
rhs: &Self::Storage<R, f32>, | ||
grad_rhs: &mut Self::Storage<R, f32>, | ||
grad_out: &Self::Storage<O, f32>, | ||
) -> Result<(), Self::Err> { | ||
todo!() | ||
let patches_numel = op.batch * op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in; | ||
let mut patches = self.dev.alloc_zeros_async::<f32>(patches_numel)?; | ||
|
||
{ | ||
// unfold grad_out into patches | ||
let unfold_fn = self.dev.get_func(MODULE_NAME, UNFOLD_OUTPUT_FN).unwrap(); | ||
let cfg = LaunchConfig::for_num_elems(patches_numel as u32); | ||
let params = (op, grad_out.data.as_ref(), &mut patches); | ||
unsafe { unfold_fn.launch_async(cfg, params) }?; | ||
} | ||
|
||
let filters_numel = op.batch * op.chan_in * op.chan_out * op.kernel * op.kernel; | ||
let mut f_b1023 = self.dev.alloc_zeros_async::<f32>(filters_numel)?; | ||
let mut grad_f_b1023 = self.dev.alloc_zeros_async::<f32>(filters_numel)?; | ||
|
||
{ | ||
// prepare filters for backward operations by | ||
// swapping dims 0 and 1 and adding a batch dimension | ||
let tr_fn = self.dev.get_func(MODULE_NAME, BR_TR_FILTERS_FN).unwrap(); | ||
let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32); | ||
let params = (op, rhs.data.as_ref(), &mut f_b1023); | ||
unsafe { tr_fn.launch_async(cfg, params) }?; | ||
} | ||
|
||
{ | ||
// img_g += filters * patches | ||
// (B, C, H * W) += (B, C, O * K * K) * (B, O * K * K, H * W) | ||
let m = op.chan_in; | ||
let k = op.chan_out * op.kernel * op.kernel; | ||
let n = op.h_in * op.w_in; | ||
unsafe { | ||
sgemm_batch( | ||
self.blas.as_ref(), | ||
(op.batch, m, k, n), | ||
&f_b1023, | ||
[m * k, k, 1], | ||
&patches, | ||
[k * n, n, 1], | ||
1.0, | ||
Arc::make_mut(&mut grad_lhs.data), | ||
[m * n, n, 1], | ||
) | ||
.unwrap(); | ||
} | ||
} | ||
|
||
{ | ||
// weight_g += img * patches^T | ||
// (B, C, O * K * K) += (B, C, H * W) * (B, H * W, O * K * K) | ||
let m = op.chan_in; | ||
let k = op.h_in * op.w_in; | ||
let n = op.chan_out * op.kernel * op.kernel; | ||
unsafe { | ||
sgemm_batch( | ||
self.blas.as_ref(), | ||
(op.batch, m, k, n), | ||
lhs.data.as_ref(), | ||
[m * k, k, 1], | ||
&patches, | ||
[k * n, 1, k], | ||
1.0, | ||
&mut grad_f_b1023, | ||
[m * n, n, 1], | ||
) | ||
.unwrap(); | ||
} | ||
|
||
// sum all the gradients collected in our broadcasted grad_f | ||
// into grad_rhs | ||
let sum_fn = self.dev.get_func(MODULE_NAME, COLLECT_GRADS_FN).unwrap(); | ||
let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32); | ||
let params = (op, &grad_f_b1023, Arc::make_mut(&mut grad_rhs.data)); | ||
unsafe { sum_fn.launch_async(cfg, params) }?; | ||
} | ||
|
||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.