Upscale2D and ConvTrans2D #603
Merged: coreylowman merged 39 commits into coreylowman:main from opfromthestart:conv2dtranspose on Mar 30, 2023.
Commits (39, all by opfromthestart):
f1ce9b5 PReLU forward works-ish
068183c Doctest works now
ef62baa Merge branch 'coreylowman:main' into main
f2fe0c4 Better-ish implementation
94f7a20 Merge branch 'main' of github.com:opfromthestart/dfdx into main
f3ef318 Now has actual layers to use
b132fcb fmt, clippy, tests
a42e0d2 Cleaning
c248c7e Remove one unneeded generic
4caa8d8 Cuda maybe working (idk)
d9c1346 LeakyReLU now should work
641f429 Fix test
22ef1f9 Merge branch 'main' of github.com:coreylowman/dfdx into real_origin/head
8cf7ac0 Base things
7d11365 Upscale operation now added
92d92c6 test with uneven upsize
f71b82e Interpolation is now a trait bound (fmt)
2bc6c82 Added test for convtrans2d
a3bb955 Convtrans tests and impls
14307e7 ConvTrans and Upscale modules
9778a5d Format & 4d
2bab567 Upscale binilear works
1c4d081 Fix of things not being found
a9e1838 put imports under flag
95b5ad5 Fix 1
17b1803 Fix 2
872357a Fix 3
31c207b Fix 4
45cb9f6 Fixes for the rest
b635dfe Formatting
c2ddb6f Merge branch 'real_origin/head' into conv2dtranspose
0d01b52 Merge branch 'coreylowman:main' into conv2dtranspose
40b56ad Fix no-std Vec
c6b0266 Clippy fixes and own error
3a5d35b fmt
69e93e8 Another clippy fix
2424b57 remove prelu
6bee6f7 unused imports
d9397b8 Forating
First changed file (new, +194 lines): the ConvTrans2D module.
use num_traits::Float;
use rand_distr::uniform::SampleUniform;

use crate::{shapes::*, tensor::*, tensor_ops::*};

use super::*;

pub mod builder {
    #[derive(Debug)]
    pub struct ConvTrans2D<
        const IN_CHAN: usize,
        const OUT_CHAN: usize,
        const KERNEL_SIZE: usize,
        const STRIDE: usize = 1,
        const PADDING: usize = 0,
    >;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
    BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P>
where
    E: Dtype,
    D: Device<E>,
    ConvTrans2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
{
    type Built = ConvTrans2D<I, O, K, S, P, E, D>;
    fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
        Self::Built::try_build(device)
    }
}
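A minimal usage sketch of the builder above, mirroring the size tests later in this diff (the `dfdx::prelude` import path and `Cpu` device are assumptions here; the forward pass itself requires the nightly feature):

```rust
use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();
    // The builder type carries only the const-generic hyperparameters;
    // dtype and device are supplied at build time.
    let conv = dev.build_module::<ConvTrans2D<3, 2, 3>, f32>();
    // A 3-channel 8x8 image upsamples to a 2-channel 10x10 image.
    let x: Tensor<Rank3<3, 8, 8>, f32, _> = dev.zeros();
    let _y: Tensor<Rank3<2, 10, 10>, f32, _> = conv.forward(x);
}
```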
/// **Requires Nightly** Performs *unbiased* 2d deconvolutions on 3d and 4d images.
///
/// **Pytorch Equivalent**: `torch.nn.ConvTranspose2d(..., bias=False)`
///
/// To create a biased conv, combine with [crate::nn::modules::Bias2D]:
/// ```ignore
/// # use dfdx::prelude::*;
/// type BiasedConv = (ConvTrans2D<3, 5, 4>, Bias2D<5>);
/// ```
///
/// Generics:
/// - `IN_CHAN`: The number of input channels in an image.
/// - `OUT_CHAN`: The number of channels in the output of the layer.
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`.
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
#[derive(Debug, Clone)]
pub struct ConvTrans2D<
    const IN_CHAN: usize,
    const OUT_CHAN: usize,
    const KERNEL_SIZE: usize,
    const STRIDE: usize,
    const PADDING: usize,
    E: Dtype,
    D: DeviceStorage,
> {
    pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}
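The shape bookkeeping in the size tests below all follows from one formula: for an unbiased transposed convolution, OUT = (IN - 1) * STRIDE + KERNEL_SIZE - 2 * PADDING. A small sketch (not part of the diff) that checks a few of the cases exercised later:

```rust
/// Output spatial size of a 2d transposed convolution:
/// OUT = (IN - 1) * STRIDE + KERNEL_SIZE - 2 * PADDING.
const fn convtrans2d_out(in_dim: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (in_dim - 1) * stride + kernel - 2 * padding
}

fn main() {
    // These match the Rank3 cases in `test_forward_3d_sizes` below.
    assert_eq!(convtrans2d_out(8, 3, 1, 0), 10); // ConvTrans2D<3, 2, 3>
    assert_eq!(convtrans2d_out(8, 3, 2, 0), 17); // ConvTrans2D<3, 2, 3, 2>
    assert_eq!(convtrans2d_out(8, 3, 2, 2), 13); // ConvTrans2D<3, 2, 3, 2, 2>
}
```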
impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
    TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, E, D>
where
    E: Dtype + Float + SampleUniform,
    D: Device<E>,
{
    type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, E2, D2>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        visitor.visit_fields(
            Self::tensor(
                "weight",
                |s| &s.weight,
                |s| &mut s.weight,
                TensorOptions::reset_with(|t| {
                    let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt();
                    t.try_fill_with_distr(rand_distr::Uniform::new(-b, b))
                }),
            ),
            |weight| ConvTrans2D { weight },
        )
    }
}
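Note the reset behavior above: each weight is redrawn from the uniform distribution U(-b, b) with b = 1 / sqrt(IN_CHAN * KERNEL_SIZE^2), the usual fan-in scaling. For the ConvTrans2D<2, 4, 3> module in the optimizer test below, that gives b = 1/sqrt(18) ≈ 0.24.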
#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
    Module<Img> for ConvTrans2D<C, O, K, S, P, E, D>
where
    E: Dtype,
    D: Device<E>,
    Img: TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
{
    type Output = Img::Output;
    type Error = D::Err;

    fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
        x.try_convtrans2d_to(self.weight.clone())
    }
}
impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
    NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
where
    E: Dtype,
    D: DeviceStorage,
{
}
#[cfg(feature = "nightly")]
#[cfg(test)]
mod tests {
    use crate::{
        optim::*,
        tensor::{AsArray, SampleTensor, ZerosTensor},
        tests::*,
    };

    use super::{builder::ConvTrans2D, *};

    #[rustfmt::skip]
    #[test]
    fn test_forward_3d_sizes() {
        let dev: TestDevice = Default::default();
        let x = dev.zeros::<Rank3<3, 8, 8>>();
        let _: Tensor<Rank3<2, 10, 10>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<4, 10, 10>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<4, 9, 9>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<4, 11, 11>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 4>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<2, 17, 17>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<2, 24, 24>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<2, 8, 8>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<2, 6, 6>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank3<2, 13, 13>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
    }

    #[rustfmt::skip]
    #[test]
    fn test_forward_4d_sizes() {
        let dev: TestDevice = Default::default();
        let x = dev.zeros::<Rank4<5, 3, 8, 8>>();
        let _: Tensor<Rank4<5, 2, 10, 10>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 4, 10, 10>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 4, 9, 9>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 4, 11, 11>, _, _, _> = dev.build_module::<ConvTrans2D<3, 4, 4>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 2, 17, 17>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 2, 24, 24>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 2, 8, 8>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 2, 6, 6>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
        let _: Tensor<Rank4<5, 2, 13, 13>, _, _, _> = dev.build_module::<ConvTrans2D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
    }

    #[test]
    fn test_2_conv_sizes() {
        let dev = Cpu::default();
        type A = ConvTrans2D<4, 2, 3>;
        type B = ConvTrans2D<2, 1, 3>;
        let _: Tensor<Rank3<1, 10, 10>, _, _> = dev
            .build_module::<(A, B), TestDtype>()
            .forward(dev.zeros::<Rank3<4, 6, 6>>());
    }

    #[test]
    fn test_3_conv_sizes() {
        type A = ConvTrans2D<2, 1, 3>;
        type B = ConvTrans2D<4, 2, 3>;
        type C = ConvTrans2D<1, 4, 1, 1, 1>;

        let dev = Cpu::default();
        let _: Tensor<Rank3<1, 10, 10>, _, _> = dev
            .build_module::<(C, B, A), TestDtype>()
            .forward_mut(dev.zeros::<Rank3<1, 8, 8>>());
    }

    #[test]
    fn test_conv_with_optimizer() {
        let dev: TestDevice = Default::default();

        let mut m = dev.build_module::<ConvTrans2D<2, 4, 3>, TestDtype>();

        let weight_init = m.weight.clone();

        let mut opt = Sgd::new(&m, Default::default());
        let out = m.forward(dev.sample_normal::<Rank4<8, 2, 28, 28>>().leaky_trace());
        let g = out.square().mean().backward();

        assert_ne!(g.get(&m.weight).array(), [[[[0.0; 3]; 3]; 2]; 4]);

        opt.update(&mut m, &g).expect("unused params");

        assert_ne!(weight_init.array(), m.weight.array());
    }
}
Second changed file (new, +112 lines): the Upscale2D and Upscale2DBy modules.
#[cfg(feature = "nightly")]
use crate::prelude::{Const, Dim, Dtype, HasErr, Tape, Tensor, Upscale2DKernel, ZerosTensor};
use crate::prelude::{ConstUpscale2D, NearestNeighbor, UpscaleMethod};

#[allow(unused)]
use super::{BuildModule, Module, NonMutableModule, ZeroSizedModule};

#[derive(Debug, Default, Clone)]
pub struct Upscale2D<const OH: usize, const OW: usize = OH, M: UpscaleMethod = NearestNeighbor>(M);

impl<const OH: usize, const OW: usize, M: UpscaleMethod> ZeroSizedModule for Upscale2D<OH, OW, M> {}
impl<const OH: usize, const OW: usize, M: UpscaleMethod> NonMutableModule for Upscale2D<OH, OW, M> {}

impl<const OH: usize, const OW: usize, M: UpscaleMethod, Img: ConstUpscale2D<M>> Module<Img>
    for Upscale2D<OH, OW, M>
{
    type Output = Img::Output<OH, OW>;
    type Error = Img::Err;

    fn try_forward(&self, x: Img) -> Result<Self::Output, Img::Err> {
        x.try_upscale2d()
    }
}
#[cfg(feature = "nightly")]
#[derive(Debug, Default, Clone)]
pub struct Upscale2DBy<const H: usize, const W: usize = H, M: UpscaleMethod = NearestNeighbor>(M);
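The two layers differ in their contract: Upscale2D fixes the absolute output size (OH x OW) at the type level, while Upscale2DBy multiplies the input height and width by constant factors (H, W). Both default to NearestNeighbor interpolation.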
#[cfg(feature = "nightly")]
impl<const H: usize, const W: usize, M: UpscaleMethod> ZeroSizedModule for Upscale2DBy<H, W, M> {}
#[cfg(feature = "nightly")]
impl<const H: usize, const W: usize, M: UpscaleMethod> NonMutableModule for Upscale2DBy<H, W, M> {}

#[cfg(feature = "nightly")]
impl<
        const H: usize,
        const W: usize,
        const IH: usize,
        const IW: usize,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(C, Const<IH>, Const<IW>), E, D, T>> for Upscale2DBy<H, W, M>
where
    Tensor<(C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>: Sized,
{
    type Output = Tensor<(C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        x: Tensor<(C, Const<IH>, Const<IW>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        x.try_upscale2d()
    }
}

#[cfg(feature = "nightly")]
impl<
        const H: usize,
        const W: usize,
        const IH: usize,
        const IW: usize,
        B: Dim,
        C: Dim,
        E: Dtype,
        M: UpscaleMethod,
        D: Upscale2DKernel<E, M> + ZerosTensor<E>,
        T: 'static + Tape<E, D>,
    > Module<Tensor<(B, C, Const<IH>, Const<IW>), E, D, T>> for Upscale2DBy<H, W, M>
where
    Tensor<(B, C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>: Sized,
{
    type Output = Tensor<(B, C, Const<{ IH * H }>, Const<{ IW * W }>), E, D, T>;
    type Error = <Self::Output as HasErr>::Err;

    fn try_forward(
        &self,
        x: Tensor<(B, C, Const<IH>, Const<IW>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        x.try_upscale2d()
    }
}
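The `Const<{ IH * H }>` output shapes are why the Upscale2DBy impls are nightly-only: arithmetic on const generics needs the unstable `generic_const_exprs` feature. A self-contained sketch of the same pattern, using plain arrays instead of dfdx tensors:

```rust
#![feature(generic_const_exprs)]

/// Upscale a 1d "image" by an integer factor H, computing the output
/// length `N * H` in the type, as Upscale2DBy does for tensor shapes.
fn upscale_nearest<const N: usize, const H: usize>(x: [f32; N]) -> [f32; N * H]
where
    [f32; N * H]: Sized,
{
    let mut out = [0.0; N * H];
    let mut i = 0;
    while i < N * H {
        // Nearest neighbor: each input element is repeated H times.
        out[i] = x[i / H];
        i += 1;
    }
    out
}

fn main() {
    let up = upscale_nearest::<4, 2>([1.0, 2.0, 3.0, 4.0]);
    assert_eq!(up, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
}
```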
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::Bilinear, shapes::*, tensor::*, tests::*};

    #[test]
    fn test_upscale2d() {
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank3<3, 4, 4>, TestDtype, _> = dev.zeros();
        let _: Tensor<Rank3<3, 8, 8>, _, _> = Upscale2D::<8>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 8, 12>, _, _> = Upscale2D::<8, 12>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 9, 9>, _, _> =
            Upscale2D::<9, 9, NearestNeighbor>::default().forward(x.clone());
    }

    #[cfg(feature = "nightly")]
    #[test]
    fn test_upscale2dby() {
        let dev: TestDevice = Default::default();
        let x: Tensor<Rank3<3, 4, 4>, TestDtype, _> = dev.zeros();
        let _: Tensor<Rank3<3, 8, 8>, _, _> = Upscale2DBy::<2>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 8, 12>, _, _> = Upscale2DBy::<2, 3>::default().forward(x.clone());
        let _: Tensor<Rank3<3, 12, 12>, _, _> =
            Upscale2DBy::<3, 3, Bilinear>::default().forward(x.clone());
    }
}
Review conversation:

"You'll need to add to the builders mod below as well."

"I did this for ConvTrans2D, but neither upscale has a builder."

"Ah yeah, but currently builders re-exports modules if there's no existing builder for it. The intent is that someone can either do `use dfdx::modules::*` or `use dfdx::builders::*` and get everything either way."
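For readers following that exchange, a hedged sketch of the arrangement being described (module paths are illustrative, not taken from this diff):

```rust
// Layers with a dedicated builder expose it under `builders`; layers
// without one, like the upscale layers here, are re-exported directly,
// so both glob imports see every layer type.
pub mod modules {
    pub use super::upscale::{Upscale2D, Upscale2DBy};
}

pub mod builders {
    // ConvTrans2D has a real builder struct...
    pub use super::convtrans::builder::ConvTrans2D;
    // ...while the upscale layers have none, so the module types stand in.
    pub use super::modules::{Upscale2D, Upscale2DBy};
}
```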