Skip to content

Commit

Permalink
Flatten2D now accepts generic batch dim (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Apr 20, 2023
1 parent a48743b commit 10427a9
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/nn/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@ where
}

#[cfg(feature = "nightly")]
impl<const B: usize, const C: usize, const H: usize, const W: usize, D, E: Dtype, T>
Module<Tensor<Rank4<B, C, H, W>, E, D, T>> for Flatten2D
impl<B: Dim, const C: usize, const H: usize, const W: usize, D, E: Dtype, T>
Module<Tensor<(B, Const<C>, Const<H>, Const<W>), E, D, T>> for Flatten2D
where
D: Device<E>,
T: Tape<E, D>,
Rank2<B, { C * H * W }>: Sized,
(B, Const<{ C * H * W }>): Sized,
{
type Output = Tensor<Rank2<B, { C * H * W }>, E, D, T>;
type Output = Tensor<(B, Const<{ C * H * W }>), E, D, T>;
type Error = D::Err;

fn try_forward(
&self,
input: Tensor<Rank4<B, C, H, W>, E, D, T>,
input: Tensor<(B, Const<C>, Const<H>, Const<W>), E, D, T>,
) -> Result<Self::Output, D::Err> {
input.try_reshape()
let batch = input.shape.0;
input.try_reshape_like(&(batch, Const)).unwrap()
}
}

Expand All @@ -60,5 +61,8 @@ mod tests {
Flatten2D.forward_mut(dev.zeros::<Rank3<10, 5, 2>>());
let _: Tensor<Rank2<5, 24>, TestDtype, _> =
Flatten2D.forward_mut(dev.zeros::<Rank4<5, 4, 3, 2>>());
let x = dev.zeros_like(&(5, Const::<4>, Const::<3>, Const::<2>));
let y: Tensor<(usize, Const<24>), TestDtype, _> = Flatten2D.forward_mut(x);
assert_eq!(y.shape.0, 5);
}
}

0 comments on commit 10427a9

Please sign in to comment.