Skip to content

Commit

Permalink
[Breaking] Adding dilation/groups to Conv2D. Adding dilation to Pool2D (
Browse files Browse the repository at this point in the history
coreylowman#767)

* Forward pass for CPU working with dilation/groups

* Adding test generation script

* Backward img grad working

* Backward filters working

* Removing large test case

* Adding sketches of cuda/cudnn kernels

* Update cudnn code

* Temp commit

* Update filters transpose logic for cuda

* Cuda kernel implementation

* batched conv2d tests

* Fixing dilation

* Cudnn passing tests

* cuda tests passing

* Reverting numpy tests

* Removing old conv2d implementation

* Pool2D rewrite

* Updates for pool2d cuda

* Adding dilated pool2d test

* Adding dilation & groups to nn layer

* Adding conv2d dilation & group tests

* Fixing tests & warning
  • Loading branch information
coreylowman authored and mattjurenka committed May 13, 2023
1 parent 43fe597 commit 5d3fc0e
Show file tree
Hide file tree
Showing 16 changed files with 1,357 additions and 1,268 deletions.
91 changes: 74 additions & 17 deletions src/nn/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,28 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::Conv2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::Conv2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
Conv2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
Conv2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv2D<I, O, K, S, P, E, D>;
type Built = Conv2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
Expand All @@ -45,26 +56,43 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`.
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.
#[derive(Debug, Clone)]
pub struct Conv2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: DeviceStorage,
> {
/// Filter weights, stored as a rank-4 tensor of shape
/// `(OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE)`.
///
/// This is the only field: `STRIDE`, `PADDING`, `DILATION` and `GROUPS` are carried
/// purely in the type and are not stored at runtime.
// NOTE(review): the second dimension is the full `IN_CHAN` regardless of `GROUPS`.
// Grouped convolutions conventionally use `IN_CHAN / GROUPS` input channels per
// filter; confirm this shape is what the conv2d op expects when `GROUPS > 1`.
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for Conv2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for Conv2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
type To<E2: Dtype, D2: Device<E2>> = Conv2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = Conv2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
Expand All @@ -85,23 +113,52 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for Conv2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for Conv2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConv2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
(Img, Tensor<Rank4<O, C, K, K>, E, D>): TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_conv2d_to(self.weight.clone())
type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_conv2d(Const, Const, Const, Const)
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for Conv2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for Conv2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: DeviceStorage,
Expand Down
64 changes: 48 additions & 16 deletions src/nn/pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#[cfg(feature = "nightly")]
use crate::tensor_ops::{ConstAvgPool2D, ConstMaxPool2D, ConstMinPool2D};
use crate::{
shapes::Const,
tensor_ops::{Pool2DKind, TryPool2D},
};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
Expand All @@ -11,8 +14,14 @@ use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct AvgPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct AvgPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

/// Max pool with 2d kernel that operates on images (3d) and batches of images (4d).
/// Each patch reduces to the maximum value in that patch.
Expand All @@ -21,8 +30,14 @@ pub struct AvgPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PA
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct MaxPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct MaxPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

/// Minimum pool with 2d kernel that operates on images (3d) and batches of images (4d).
/// Each patch reduces to the minimum of the values in the patch.
Expand All @@ -31,31 +46,48 @@ pub struct MaxPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PA
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct MinPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct MinPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

macro_rules! impl_pools {
($PoolTy:tt, $Trait:ident) => {
impl<const K: usize, const S: usize, const P: usize> ZeroSizedModule for $PoolTy<K, S, P> {}
impl<const K: usize, const S: usize, const P: usize> NonMutableModule for $PoolTy<K, S, P> {}
($PoolTy:tt, $Op:expr) => {
impl<const K: usize, const S: usize, const P: usize, const L: usize> ZeroSizedModule
for $PoolTy<K, S, P, L>
{
}
impl<const K: usize, const S: usize, const P: usize, const L: usize> NonMutableModule
for $PoolTy<K, S, P, L>
{
}

#[cfg(feature = "nightly")]
impl<const K: usize, const S: usize, const P: usize, Img: $Trait<K, S, P>> Module<Img>
for $PoolTy<K, S, P>
impl<
const K: usize,
const S: usize,
const P: usize,
const L: usize,
Img: TryPool2D<Const<K>, Const<S>, Const<P>, Const<L>>,
> Module<Img> for $PoolTy<K, S, P, L>
{
type Output = Img::Output;
type Error = Img::Err;
type Output = Img::Pooled;
type Error = Img::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Img::Err> {
x.try_pool2d()
fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
x.try_pool2d($Op, Const, Const, Const, Const)
}
}
};
}

impl_pools!(AvgPool2D, ConstAvgPool2D);
impl_pools!(MaxPool2D, ConstMaxPool2D);
impl_pools!(MinPool2D, ConstMinPool2D);
impl_pools!(AvgPool2D, Pool2DKind::Avg);
impl_pools!(MaxPool2D, Pool2DKind::Max);
impl_pools!(MinPool2D, Pool2DKind::Min);

#[cfg(feature = "nightly")]
#[cfg(test)]
Expand Down
24 changes: 24 additions & 0 deletions src/shapes/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ where
}
}

/// Allows `runtime_dim * Const::<N>`: a runtime `usize` on the left multiplied by a
/// compile-time `Const<N>` on the right.
///
/// The result is a plain `usize`, computed as the product of both operands' `size()`.
impl<const N: usize> core::ops::Mul<Const<N>> for usize {
    type Output = usize;

    fn mul(self, rhs: Const<N>) -> usize {
        // Read both operands through `size()` so the two dimension kinds are
        // handled uniformly, then multiply as ordinary integers.
        let lhs_size = self.size();
        let rhs_size = rhs.size();
        lhs_size * rhs_size
    }
}
/// Allows `Const::<N> * runtime_dim`: a compile-time `Const<N>` on the left multiplied
/// by a runtime `usize` on the right.
///
/// The result is a plain `usize`, computed as the product of both operands' `size()`.
impl<const N: usize> core::ops::Mul<usize> for Const<N> {
    type Output = usize;

    fn mul(self, rhs: usize) -> usize {
        // Read both operands through `size()` so the two dimension kinds are
        // handled uniformly, then multiply as ordinary integers.
        let lhs_size = self.size();
        let rhs_size = rhs.size();
        lhs_size * rhs_size
    }
}

/// Type-level multiplication of two compile-time dimensions:
/// `Const<M> * Const<N>` has type `Const<{ N * M }>`.
///
/// Neither operand is inspected at runtime; the right-hand side is ignored and the
/// unit value `Const` is returned, so the product lives entirely in the output type.
///
/// Only compiled when the `nightly` feature is enabled.
// NOTE(review): the `Const<{ N * M }>: Sized` bound is presumably there to satisfy the
// nightly generic-const-expression rules for the computed `{ N * M }`; confirm before
// changing the expression's form, since callers may need to spell it identically.
#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Mul<Const<N>> for Const<M>
where
Const<{ N * M }>: Sized,
{
type Output = Const<{ N * M }>;
fn mul(self, _: Const<N>) -> Self::Output {
Const
}
}

/// Represents either `[T; N]` or `Vec<T>`
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
Expand Down
Loading

0 comments on commit 5d3fc0e

Please sign in to comment.