Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Adding dilation/groups to Conv2D. Adding dilation to Pool2D #767

Merged
merged 23 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 74 additions & 17 deletions src/nn/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,28 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::Conv2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::Conv2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
Conv2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
Conv2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv2D<I, O, K, S, P, E, D>;
type Built = Conv2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
Expand All @@ -45,26 +56,43 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.
#[derive(Debug, Clone)]
pub struct Conv2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: DeviceStorage,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for Conv2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for Conv2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
type To<E2: Dtype, D2: Device<E2>> = Conv2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = Conv2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
Expand All @@ -85,23 +113,52 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for Conv2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for Conv2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConv2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
(Img, Tensor<Rank4<O, C, K, K>, E, D>): TryConv2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_conv2d_to(self.weight.clone())
type Output = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<O, C, K, K>, E, D>) as TryConv2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_conv2d(Const, Const, Const, Const)
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for Conv2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for Conv2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: DeviceStorage,
Expand Down
64 changes: 48 additions & 16 deletions src/nn/pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#[cfg(feature = "nightly")]
use crate::tensor_ops::{ConstAvgPool2D, ConstMaxPool2D, ConstMinPool2D};
use crate::{
shapes::Const,
tensor_ops::{Pool2DKind, TryPool2D},
};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
Expand All @@ -11,8 +14,14 @@ use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION` How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct AvgPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct AvgPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

/// Max pool with 2d kernel that operates on images (3d) and batches of images (4d).
/// Each patch reduces to the maximum value in that patch.
Expand All @@ -21,8 +30,14 @@ pub struct AvgPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PA
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION` How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct MaxPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct MaxPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

/// Minimum pool with 2d kernel that operates on images (3d) and batches of images (4d).
/// Each patch reduces to the minimum of the values in the patch.
Expand All @@ -31,31 +46,48 @@ pub struct MaxPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PA
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION` How dilated the kernel should be. Defaults to `1`.
#[derive(Debug, Default, Clone)]
pub struct MinPool2D<const KERNEL_SIZE: usize, const STRIDE: usize = 1, const PADDING: usize = 0>;
pub struct MinPool2D<
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
>;

macro_rules! impl_pools {
($PoolTy:tt, $Trait:ident) => {
impl<const K: usize, const S: usize, const P: usize> ZeroSizedModule for $PoolTy<K, S, P> {}
impl<const K: usize, const S: usize, const P: usize> NonMutableModule for $PoolTy<K, S, P> {}
($PoolTy:tt, $Op:expr) => {
impl<const K: usize, const S: usize, const P: usize, const L: usize> ZeroSizedModule
for $PoolTy<K, S, P, L>
{
}
impl<const K: usize, const S: usize, const P: usize, const L: usize> NonMutableModule
for $PoolTy<K, S, P, L>
{
}

#[cfg(feature = "nightly")]
impl<const K: usize, const S: usize, const P: usize, Img: $Trait<K, S, P>> Module<Img>
for $PoolTy<K, S, P>
impl<
const K: usize,
const S: usize,
const P: usize,
const L: usize,
Img: TryPool2D<Const<K>, Const<S>, Const<P>, Const<L>>,
> Module<Img> for $PoolTy<K, S, P, L>
{
type Output = Img::Output;
type Error = Img::Err;
type Output = Img::Pooled;
type Error = Img::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Img::Err> {
x.try_pool2d()
fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
x.try_pool2d($Op, Const, Const, Const, Const)
}
}
};
}

impl_pools!(AvgPool2D, ConstAvgPool2D);
impl_pools!(MaxPool2D, ConstMaxPool2D);
impl_pools!(MinPool2D, ConstMinPool2D);
impl_pools!(AvgPool2D, Pool2DKind::Avg);
impl_pools!(MaxPool2D, Pool2DKind::Max);
impl_pools!(MinPool2D, Pool2DKind::Min);

#[cfg(feature = "nightly")]
#[cfg(test)]
Expand Down
24 changes: 24 additions & 0 deletions src/shapes/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ where
}
}

impl<const N: usize> core::ops::Mul<Const<N>> for usize {
type Output = usize;
fn mul(self, rhs: Const<N>) -> Self::Output {
self.size() * rhs.size()
}
}
impl<const N: usize> core::ops::Mul<usize> for Const<N> {
type Output = usize;
fn mul(self, rhs: usize) -> Self::Output {
self.size() * rhs.size()
}
}

#[cfg(feature = "nightly")]
impl<const N: usize, const M: usize> core::ops::Mul<Const<N>> for Const<M>
where
Const<{ N * M }>: Sized,
{
type Output = Const<{ N * M }>;
fn mul(self, _: Const<N>) -> Self::Output {
Const
}
}

/// Represents either `[T; N]` or `Vec<T>`
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
Expand Down
Loading