Using image/filter stride in cuda kernel for conv #495

Merged: 1 commit, Feb 26, 2023
24 changes: 16 additions & 8 deletions src/tensor_ops/conv2d/conv2d.cu
@@ -15,6 +15,7 @@ template<typename T>
__device__ void unfold_input_into_patches(
const Conv2DOp op,
const T *image, // 4d (Batch, Channels, Height, Width)
+const size_t *strides, // 4d image strides
T *patches // 6d (Batch, Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,7 +57,7 @@ __device__ void unfold_input_into_patches(
return;
}

-const size_t i_image = b * (op.chan_in * op.h_in * op.w_in) + c * (op.h_in * op.w_in) + y * (op.w_in) + x;
+const size_t i_image = b * strides[0] + c * strides[1] + y * strides[2] + x * strides[3];
patches[i] = image[i_image];
}
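The core of the change is visible here: the flat image index is now computed from the tensor's actual strides instead of an assumed contiguous (Batch, Channels, Height, Width) layout. As a quick sanity check (my sketch, not code from the PR), the strided formula reduces to the old contiguous one when the strides are the default row-major ones, while staying correct for permuted or broadcast layouts:

fn main() {
    let (chan_in, h_in, w_in) = (3usize, 8, 8);
    // Default row-major strides for a (Batch, Channels, Height, Width) image.
    let strides = [chan_in * h_in * w_in, h_in * w_in, w_in, 1];
    let (b, c, y, x) = (1usize, 2, 4, 5);

    // Old formula: hard-codes the contiguous layout.
    let i_contig = b * (chan_in * h_in * w_in) + c * (h_in * w_in) + y * w_in + x;
    // New formula: uses whatever strides the tensor actually carries.
    let i_strided = b * strides[0] + c * strides[1] + y * strides[2] + x * strides[3];

    assert_eq!(i_contig, i_strided);
}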

@@ -120,6 +121,7 @@ template<typename T>
__device__ void transpose_and_broadcast_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
+const size_t *strides, // 4d filters strides
T *filters_tr // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -139,8 +141,9 @@ __device__ void transpose_and_broadcast_filters(
idx /= op.chan_out;

auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
+auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

-const T f = filters[i];
+const T f = filters[i_no];
for (auto b = 0; b < op.batch; b++) {
filters_tr[b * numel + i_tr] = f;
}
@@ -150,7 +153,8 @@ template<typename T>
__device__ void sum_transposed_filters(
const Conv2DOp op,
const T *filters_tr, // 5d (Batch, ChanIn, ChanOut, KernelSize, KernelSize)
-T *filters // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
+T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
+const size_t *strides // 4d filter strides
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
auto numel = op.chan_out * op.chan_in * op.kernel * op.kernel;
@@ -169,22 +173,24 @@ __device__ void sum_transposed_filters(
idx /= op.chan_out;

auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2;
+auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];

T tmp = 0.0;
for (auto b = 0; b < op.batch; b++) {
tmp += filters_tr[b * numel + i_tr];
}

-filters[i] += tmp;
+filters[i_no] += tmp;
}
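For reference, a host-side Rust transcription (illustrative only, not part of the PR) of what one sum_transposed_filters thread computes: sum the per-batch copies out of the contiguous 5d buffer, then accumulate into the filter gradient at its strided location.

// One logical thread of the kernel, transcribed to safe Rust.
fn sum_one(
    filters_tr: &[f32],   // 5d (Batch, ChanIn, ChanOut, K, K), contiguous
    filters: &mut [f32],  // 4d (ChanOut, ChanIn, K, K), arbitrary strides
    strides: [usize; 4],
    (batch, chan_in, chan_out, k): (usize, usize, usize, usize),
    (o, c, k1, k2): (usize, usize, usize, usize),
) {
    let numel = chan_out * chan_in * k * k;
    let i_tr = c * (chan_out * k * k) + o * (k * k) + k1 * k + k2;
    let i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
    let mut tmp = 0.0;
    for b in 0..batch {
        tmp += filters_tr[b * numel + i_tr];
    }
    filters[i_no] += tmp;
}

fn main() {
    // Tiny case: Batch=2, all other dims 1; the two batch copies sum to 2.0.
    let filters_tr = vec![1.0f32; 2];
    let mut filters = vec![0.0f32; 1];
    sum_one(&filters_tr, &mut filters, [1, 1, 1, 1], (2, 1, 1, 1), (0, 0, 0, 0));
    assert_eq!(filters[0], 2.0);
}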

#define CONV_OP(TYPENAME, UNFOLD_INPUT, UNFOLD_OUTPUT, TR_FILTERS, SUM_TR_FILTERS) \
extern "C" __global__ void UNFOLD_INPUT( \
const Conv2DOp op, \
const TYPENAME *image, \
+const size_t *strides, \
TYPENAME *patches \
) { \
-unfold_input_into_patches(op, image, patches); \
+unfold_input_into_patches(op, image, strides, patches); \
} \
extern "C" __global__ void UNFOLD_OUTPUT( \
const Conv2DOp op, \
@@ -196,16 +202,18 @@ extern "C" __global__ void UNFOLD_OUTPUT( \
extern "C" __global__ void TR_FILTERS( \
const Conv2DOp op, \
const TYPENAME *filters, \
+const size_t *strides, \
TYPENAME *filters_tr \
) { \
-transpose_and_broadcast_filters(op, filters, filters_tr); \
+transpose_and_broadcast_filters(op, filters, strides, filters_tr); \
} \
extern "C" __global__ void SUM_TR_FILTERS( \
const Conv2DOp op, \
const TYPENAME *filters_tr, \
-TYPENAME *filters \
+TYPENAME *filters, \
+const size_t *strides \
) { \
-sum_transposed_filters(op, filters_tr, filters); \
+sum_transposed_filters(op, filters_tr, filters, strides); \
}

CONV_OP(
28 changes: 18 additions & 10 deletions src/tensor_ops/conv2d/cuda_kernel.rs
@@ -35,6 +35,14 @@ impl HasCudaKernel<f64> for Cuda {
];
}

+fn make_4d<S: Shape>(strides: S::Concrete) -> [usize; 4] {
+match S::NUM_DIMS {
+3 => [0, strides[0], strides[1], strides[2]],
+4 => [strides[0], strides[1], strides[2], strides[3]],
+_ => unreachable!("Only implemented for 3d & 4d arrays"),
+}
+}

impl<E: Dtype + ValidAsZeroBits> super::Conv2DKernel<E> for Cuda
where
Self: HasCudaKernel<E>,
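The new make_4d helper in the hunk above normalizes 3d and 4d stride arrays to a common 4d form. The 0 it inserts for the batch stride means the batch index contributes nothing to the offset, which is exactly right for an unbatched image. A self-contained sketch of the same rule, with a plain slice standing in for dfdx's Shape trait (my simplification, not the PR's code):

fn make_4d(strides: &[usize]) -> [usize; 4] {
    match strides.len() {
        // Unbatched 3d image: a batch stride of 0 adds nothing to the offset.
        3 => [0, strides[0], strides[1], strides[2]],
        4 => [strides[0], strides[1], strides[2], strides[3]],
        _ => unreachable!("only implemented for 3d & 4d arrays"),
    }
}

fn main() {
    // Row-major strides of a (C=3, H=4, W=5) image are [20, 5, 1].
    assert_eq!(make_4d(&[20, 5, 1]), [0, 20, 5, 1]);
    // A 4d (B, C, H, W) tensor's strides pass through unchanged.
    assert_eq!(make_4d(&[60, 20, 5, 1]), [60, 20, 5, 1]);
}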
@@ -47,22 +55,16 @@ where
rhs: &Self::Storage<R, E>,
out: &mut Self::Storage<O, E>,
) -> Result<(), Self::Err> {
-assert_eq!(
-lhs.shape().strides(),
-lhs.strides,
-"Only works with contiguous image strides"
-);

if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}

let patches_numel = op.batch * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out;
let mut patches = self.dev.alloc_zeros_async::<E>(patches_numel)?;

+let img_strides = self.dev.take_async(make_4d::<L>(lhs.strides).into())?;
let unfold_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
let cfg = LaunchConfig::for_num_elems(patches.len() as u32);
-let params = (op, lhs.data.as_ref(), &mut patches);
+let params = (op, lhs.data.as_ref(), &img_strides, &mut patches);
unsafe { unfold_fn.launch_async(cfg, params) }?;

// (O, C * K * K) * (B, C * K * K, OH * OW) = (B, O, OH * OW)
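The comment above summarizes the forward pass after unfolding: the filters, viewed as an (O, C*K*K) matrix, multiply the patch buffer (B, C*K*K, OH*OW), broadcasting over the batch. A quick shape check with concrete numbers (my example, not from the PR):

fn main() {
    // Example sizes: B=2, C=3, O=8, K=3, OH=OW=6.
    let (b, c, o, k, oh, ow) = (2usize, 3, 8, 3, 6, 6);
    let lhs = (o, c * k * k);          // filters viewed as (8, 27)
    let rhs = (b, c * k * k, oh * ow); // patches as (2, 27, 36)
    assert_eq!(lhs.1, rhs.1);          // inner dimensions agree
    let out = (rhs.0, lhs.0, rhs.2);   // result is (2, 8, 36)
    assert_eq!(out, (b, o, oh * ow));  // i.e. (B, O, OH * OW)
}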
@@ -110,13 +112,14 @@ where
let filters_numel = op.batch * op.chan_in * op.chan_out * op.kernel * op.kernel;
let mut f_b1023 = self.dev.alloc_zeros_async::<E>(filters_numel)?;
let mut grad_f_b1023 = self.dev.alloc_zeros_async::<E>(filters_numel)?;
+let f_strides = self.dev.take_async(rhs.strides.into())?;

{
// prepare filters for backward operations by
// swapping dims 0 and 1 and adding a batch dimension
let tr_fn = self.dev.get_func(Self::MOD, Self::FNS[2]).unwrap();
let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32);
-let params = (op, rhs.data.as_ref(), &mut f_b1023);
+let params = (op, rhs.data.as_ref(), &f_strides, &mut f_b1023);
unsafe { tr_fn.launch_async(cfg, params) }?;
}

@@ -167,7 +170,12 @@ where
// into grad_rhs
let sum_fn = self.dev.get_func(Self::MOD, Self::FNS[3]).unwrap();
let cfg = LaunchConfig::for_num_elems(rhs.shape.num_elements() as u32);
-let params = (op, &grad_f_b1023, Arc::make_mut(&mut grad_rhs.data));
+let params = (
+op,
+&grad_f_b1023,
+Arc::make_mut(&mut grad_rhs.data),
+&f_strides,
+);
unsafe { sum_fn.launch_async(cfg, params) }?;
}
