Skip to content

Commit

Permalink
Moving conv into device and cleaning up a bit (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Oct 3, 2022
1 parent c94f425 commit 12b3467
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 288 deletions.
179 changes: 179 additions & 0 deletions src/devices/conv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use super::Cpu;

/// **Requires nightly** 2d convolution with stride and padding specified at trait level.
///
/// `S` is the stride and `P` is the padding. Fixing them on the trait allows the rest of
/// the parameters (`C`, `O`, `K`, `H`, `W`) to be inferred from the argument types.
///
/// Shapes, as spelled out by the method signatures:
/// - `img`: `C x H x W`
/// - `weight`: `O x C x K x K`
/// - `bias`: `O`
/// - output: `O x ((H + 2 * P - K) / S + 1) x ((W + 2 * P - K) / S + 1)`
pub trait DeviceConv2D<const S: usize, const P: usize> {
/// Forward operation that modifies the `out` image.
///
/// NOTE: the `Cpu` implementation adds its results (and the bias) onto whatever `out`
/// already holds instead of overwriting it, so pass a zeroed buffer to get the plain
/// convolution result.
fn conv_forward<
const C: usize,
const O: usize,
const K: usize,
const H: usize,
const W: usize,
>(
img: &[[[f32; W]; H]; C],
weight: &[[[[f32; K]; K]; C]; O],
bias: &[f32; O],
out: &mut [[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
);

/// Backward operation that modifies the gradients of img, weight, and bias.
///
/// `out_g` has the same shape as the forward output; presumably it is the gradient
/// flowing back from that output. `img` and `weight` are the forward inputs.
///
/// NOTE: the `Cpu` implementation accumulates (`+=`) into `img_g`, `weight_g` and
/// `bias_g`; it does not reset them first.
fn conv_backward<
const C: usize,
const O: usize,
const K: usize,
const H: usize,
const W: usize,
>(
img: &[[[f32; W]; H]; C],
weight: &[[[[f32; K]; K]; C]; O],
out_g: &[[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
img_g: &mut [[[f32; W]; H]; C],
weight_g: &mut [[[[f32; K]; K]; C]; O],
bias_g: &mut [f32; O],
);
}

impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu {
    /// Naive CPU forward pass of a 2d convolution with stride `S` and padding `P`.
    ///
    /// For every output position the kernel window is slid over the (virtually)
    /// zero-padded image; positions that fall into the padding contribute nothing.
    /// Results are *accumulated* into `out` (`+=`), and `bias[oc]` is added to every
    /// element of output channel `oc` afterwards, so `out` should be zeroed by the
    /// caller if a plain convolution result is wanted.
    fn conv_forward<
        const C: usize,
        const O: usize,
        const K: usize,
        const H: usize,
        const W: usize,
    >(
        img: &[[[f32; W]; H]; C],
        weight: &[[[[f32; K]; K]; C]; O],
        bias: &[f32; O],
        out: &mut [[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
    ) {
        let out_height = (H + 2 * P - K) / S + 1;
        let out_width = (W + 2 * P - K) / S + 1;

        // Loop nesting is kept as-is so the floating point summation order per output
        // element (channel, then kernel row, then kernel column) does not change.
        for c in 0..C {
            for oc in 0..O {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let o = &mut out[oc][oh][ow];
                        for k1 in 0..K {
                            // The row only depends on (oh, k1): resolve it once instead
                            // of re-checking it for every kernel column.
                            if let Some(y) = input_coord(oh, k1, S, P, H) {
                                for k2 in 0..K {
                                    if let Some(x) = input_coord(ow, k2, S, P, W) {
                                        *o += weight[oc][c][k1][k2] * img[c][y][x];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Bias is added last, once per output element.
        for oc in 0..O {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    out[oc][oh][ow] += bias[oc];
                }
            }
        }
    }

    /// Naive CPU backward pass matching [`DeviceConv2D::conv_forward`].
    ///
    /// Given `out_g` (same shape as the forward output), accumulates into:
    /// - `bias_g[oc]`: the sum of `out_g[oc]` over all spatial positions,
    /// - `weight_g[oc][c][k1][k2]`: `img[c][y][x] * out_g[oc][oh][ow]`,
    /// - `img_g[c][y][x]`: `weight[oc][c][k1][k2] * out_g[oc][oh][ow]`,
    ///
    /// where `(y, x)` is the input position touched by kernel offset `(k1, k2)` at
    /// output position `(oh, ow)`. Positions in the padding are skipped. All three
    /// gradient buffers are added to, never reset.
    fn conv_backward<
        const C: usize,
        const O: usize,
        const K: usize,
        const H: usize,
        const W: usize,
    >(
        img: &[[[f32; W]; H]; C],
        weight: &[[[[f32; K]; K]; C]; O],
        out_g: &[[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
        img_g: &mut [[[f32; W]; H]; C],
        weight_g: &mut [[[[f32; K]; K]; C]; O],
        bias_g: &mut [f32; O],
    ) {
        let out_height = (H + 2 * P - K) / S + 1;
        let out_width = (W + 2 * P - K) / S + 1;

        for oc in 0..O {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    bias_g[oc] += out_g[oc][oh][ow];
                }
            }
        }

        // Loop nesting is kept as-is: several (oh, ow, oc, k1, k2) combinations can hit
        // the same `img_g` element, and reordering would change the summation order.
        for c in 0..C {
            for oh in 0..out_height {
                for ow in 0..out_width {
                    for oc in 0..O {
                        let o_g = out_g[oc][oh][ow];
                        for k1 in 0..K {
                            if let Some(y) = input_coord(oh, k1, S, P, H) {
                                for k2 in 0..K {
                                    if let Some(x) = input_coord(ow, k2, S, P, W) {
                                        weight_g[oc][c][k1][k2] += img[c][y][x] * o_g;
                                        img_g[c][y][x] += weight[oc][c][k1][k2] * o_g;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

/// Maps an output coordinate plus a kernel offset back to the input coordinate it reads.
///
/// Computes `out_idx * stride + k - padding` and returns it only when the result lies
/// inside `0..len`; `None` means the position falls into the zero padding (either
/// before the start or past the end of the input axis).
#[inline]
fn input_coord(
    out_idx: usize,
    k: usize,
    stride: usize,
    padding: usize,
    len: usize,
) -> Option<usize> {
    (out_idx * stride + k)
        .checked_sub(padding)
        .filter(|&i| i < len)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::devices::{AllocateZeros, FillElements};
    use crate::tests::assert_close;
    use rand::prelude::*;
    use rand_distr::StandardNormal;

    /// Stride 4, padding 3, kernel 2 on a 5-channel 7x6 image with 3 output channels.
    ///
    /// Checks the forward result against fixed reference values and that the backward
    /// pass writes something into every gradient buffer.
    #[test]
    fn test_conv2d_s4p3k2() {
        let mut rng = StdRng::seed_from_u64(432);
        let mut randn = |v: &mut f32| *v = rng.sample(StandardNormal);

        // The fill order (weight, bias, image) decides which random draws each tensor
        // receives, so it must stay fixed for the reference values below to hold.
        let weight: Box<[[[[f32; 2]; 2]; 5]; 3]> = Cpu::filled(&mut randn);
        let bias: Box<[f32; 3]> = Cpu::filled(&mut randn);
        let img: Box<[[[f32; 6]; 7]; 5]> = Cpu::filled(&mut randn);

        // Forward pass into a zeroed 3x3x3 output.
        let mut output = [[[0.0; 3]; 3]; 3];
        <Cpu as DeviceConv2D<4, 3>>::conv_forward(
            img.as_ref(),
            weight.as_ref(),
            bias.as_ref(),
            &mut output,
        );

        #[rustfmt::skip]
        let expected: [[[f32; 3]; 3]; 3] = [
            [[-0.57176435, -0.57176435, -0.57176435],[-0.57176435, 1.0759051, 1.4307989],[-0.57176435, -0.86296344, -1.8794353]],
            [[0.29306656, 0.29306656, 0.29306656],[0.29306656, 0.9771965, 1.467767],[0.29306656, -6.367015, -2.3370528]],
            [[-0.19717735, -0.19717735, -0.19717735],[-0.19717735, 1.3412137, 2.9476144],[-0.19717735, 4.247249, -2.1779637]],
        ];
        assert_close(&output, &expected);

        // Backward pass, feeding the forward output back in as the output gradient.
        let mut weight_grad: Box<[[[[f32; 2]; 2]; 5]; 3]> = Cpu::zeros();
        let mut bias_grad: Box<[f32; 3]> = Cpu::zeros();
        let mut img_grad: Box<[[[f32; 6]; 7]; 5]> = Cpu::zeros();
        <Cpu as DeviceConv2D<4, 3>>::conv_backward(
            img.as_ref(),
            weight.as_ref(),
            &output,
            img_grad.as_mut(),
            weight_grad.as_mut(),
            bias_grad.as_mut(),
        );

        // Every gradient buffer must have been touched.
        assert_ne!(weight_grad.as_ref(), &[[[[0.0; 2]; 2]; 5]; 3]);
        assert_ne!(bias_grad.as_ref(), &[0.0; 3]);
        assert_ne!(img_grad.as_ref(), &[[[0.0; 6]; 7]; 5]);
    }
}
5 changes: 5 additions & 0 deletions src/devices/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ pub use matmul::*;
pub use permute::*;
pub use select::*;

// The `conv` module is documented as requiring nightly Rust, so it is only compiled
// and re-exported when the `nightly` cargo feature is enabled.
#[cfg(feature = "nightly")]
mod conv;
#[cfg(feature = "nightly")]
pub use conv::*;

use std::ops::*;

/// The CPU device
Expand Down
Loading

0 comments on commit 12b3467

Please sign in to comment.