From c28a0eec2ac300744875fe4eb1e669c083fc5d63 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 18 Aug 2022 07:58:26 -0400 Subject: [PATCH] Adding min_axis function --- src/tensor_ops/impl_max_axis.rs | 16 ++---- src/tensor_ops/impl_min_axis.rs | 98 +++++++++++++++++++++++++++++++++ src/tensor_ops/mod.rs | 2 + 3 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 src/tensor_ops/impl_min_axis.rs diff --git a/src/tensor_ops/impl_max_axis.rs b/src/tensor_ops/impl_max_axis.rs index 6cfda1d06..3b3aae92e 100644 --- a/src/tensor_ops/impl_max_axis.rs +++ b/src/tensor_ops/impl_max_axis.rs @@ -80,13 +80,10 @@ mod tests { let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 2.0], [3.0, -2.0, 2.0]]); let r = t.trace().max_axis::<0>(); assert_eq!(r.data(), &[3.0, 2.0, 2.0]); - let gradients = r.exp().mean().backward(); + let g = r.exp().mean().backward(); assert_eq!( - gradients.ref_gradient(&t), - &[ - [0.00000000, 2.463019, 2.463019], - [6.695179, 0.00000000, 2.463019] - ] + g.ref_gradient(&t), + &[[0.0, 2.463019, 2.463019], [6.695179, 0.0, 2.463019]] ); } @@ -95,10 +92,7 @@ mod tests { let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 2.0], [3.0, -2.0, 2.0]]); let r = t.trace().max_axis::<-1>(); assert_eq!(r.data(), &[2.0, 3.0]); - let gradients = r.exp().mean().backward(); - assert_eq!( - gradients.ref_gradient(&t), - &[[0.0, 3.694528, 3.694528], [10.0427685, 0.0, 0.0]] - ); + let g = r.sum().backward(); + assert_eq!(g.ref_gradient(&t), &[[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]); } } diff --git a/src/tensor_ops/impl_min_axis.rs b/src/tensor_ops/impl_min_axis.rs new file mode 100644 index 000000000..9181b9277 --- /dev/null +++ b/src/tensor_ops/impl_min_axis.rs @@ -0,0 +1,98 @@ +use super::utils::move_tape_and_add_backward_op; +use crate::prelude::*; + +/// Reduces dimension `I` of the tensor by gathering the minimum value from that dimension. 
+/// +/// **Pytorch equivalent**: `t.amin(I)` +/// +/// **NOTE** This evenly distributes gradients between all equal minimum values, instead +/// of only exactly 1 value. +/// +/// Examples: +/// ```rust +/// # use dfdx::prelude::*; +/// let t = Tensor2D::new([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]); +/// let r: Tensor1D<2> = t.min_axis::<-1>(); +/// assert_eq!(r.data(), &[1.0, -3.0]); +/// ``` +pub fn min_axis, const I: isize>(mut t: T) -> T::Reduced { + let mut result = ::NoTape::zeros(); + T::DeviceR::reduce_into(t.data(), result.mut_data(), f32::min); + + // store derivative in t + T::DeviceR::foreach_br(t.mut_data(), result.data(), &mut |l, r| { + *l = if l == r { 1.0 } else { 0.0 } + }); + + move_tape_and_add_backward_op(t, result, move |mut t, result, grads| { + let (t_grad, result_grad) = grads.mut_and_ref(&t, &result); + + T::DeviceR::foreach_br(t.mut_data(), result_grad, &mut |d, r| { + *d *= r; + }); + T::Device::add(t_grad, t.data()); + }) +} + +macro_rules! min_axis_impl { + ($typename:ident, [$($Vs:tt),*]) => { +impl<$(const $Vs: usize, )* H: Tape> $typename<$($Vs, )* H> { + /// Calls [min_axis()] on `self`. 
pub fn min_axis(self) -> >::Reduced + where + Self: Reduce1, + { + min_axis::(self) + } +} + }; +} + +min_axis_impl!(Tensor0D, []); +min_axis_impl!(Tensor1D, [M]); +min_axis_impl!(Tensor2D, [M, N]); +min_axis_impl!(Tensor3D, [M, N, O]); +min_axis_impl!(Tensor4D, [M, N, O, P]); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valids_min_axis() { + let _: Tensor0D = Tensor1D::<5>::zeros().min_axis::<-1>(); + + let _: Tensor1D<3> = Tensor2D::<5, 3>::zeros().min_axis::<0>(); + let _: Tensor1D<5> = Tensor2D::<5, 3>::zeros().min_axis::<-1>(); + + let _: Tensor2D<5, 3> = Tensor3D::<7, 5, 3>::zeros().min_axis::<0>(); + let _: Tensor2D<7, 3> = Tensor3D::<7, 5, 3>::zeros().min_axis::<1>(); + let _: Tensor2D<7, 5> = Tensor3D::<7, 5, 3>::zeros().min_axis::<-1>(); + + let _: Tensor3D<7, 5, 3> = Tensor4D::<9, 7, 5, 3>::zeros().min_axis::<0>(); + let _: Tensor3D<9, 5, 3> = Tensor4D::<9, 7, 5, 3>::zeros().min_axis::<1>(); + let _: Tensor3D<9, 7, 3> = Tensor4D::<9, 7, 5, 3>::zeros().min_axis::<2>(); + let _: Tensor3D<9, 7, 5> = Tensor4D::<9, 7, 5, 3>::zeros().min_axis::<-1>(); + } + + #[test] + fn test_min_axis_0_2d() { + let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 1.0, 2.0], [3.0, -2.0, 2.0]]); + let r = t.trace().min_axis::<0>(); + assert_eq!(r.data(), &[1.0, -2.0, 2.0]); + let g = r.exp().mean().backward(); + assert_eq!( + g.ref_gradient(&t), + &[[0.90609396, 0.0, 2.463019], [0.0, 0.04511176, 2.463019]] + ); + } + + #[test] + fn test_min_axis_1_2d() { + let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 1.0, 2.0], [3.0, -2.0, 2.0]]); + let r = t.trace().min_axis::<-1>(); + assert_eq!(r.data(), &[1.0, -2.0]); + let g = r.sum().backward(); + assert_eq!(g.ref_gradient(&t), &[[1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]); + } +} diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 7f0b42317..fe9bffb00 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -102,6 +102,7 @@ mod impl_mask; mod impl_max_axis; mod impl_mean; mod impl_mean_axis; +mod impl_min_axis; mod impl_nans; mod impl_normalize_axis; mod impl_softmax; @@ -124,6 +125,7 @@ pub use impl_mask::*; pub use impl_max_axis::*; pub use impl_mean::*; pub use impl_mean_axis::*; +pub use impl_min_axis::*; pub use impl_nans::*; pub use impl_normalize_axis::*; pub use impl_softmax::*;