From 669ba517aa6547d053a27c3db9b3208ad72dae5c Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 8 Sep 2022 07:19:31 -0400 Subject: [PATCH 01/14] device select now uses traits --- src/devices/select.rs | 97 ++++++++++++++++++------------------------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index 2adf6f739..e73a2d172 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -28,80 +28,63 @@ pub trait SelectAlongAxis { -impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, usize, $DstTy, $Axis> for Cpu { - fn select_axis(inp: &$SrcTy, indices: &usize, out: &mut $DstTy) { +impl SelectAlongAxis<[T; M], usize, T, 0> for Cpu +where + Self: ForEachElement, + T: Copy + CountElements, + T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, +{ + fn select_axis(inp: &[T; M], indices: &usize, out: &mut T) { *out = inp[*indices]; } - fn select_add(inp: &mut $SrcTy, indices: &usize, out: &$DstTy) { + fn select_add(inp: &mut [T; M], indices: &usize, out: &T) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } } - }; -} -macro_rules! select_0z { - ($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => { -impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, [usize; Z], $DstTy, $Axis> for Cpu { - fn select_axis(inp: &$SrcTy, indices: &[usize; Z], out: &mut $DstTy) { +impl SelectAlongAxis<[T; M], [usize; Z], [T; Z], 0> for Cpu +where + Self: ForEachElement, + T: Copy + CountElements, + T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, +{ + fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut [T; Z]) { for z in 0..Z { out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut $SrcTy, indices: &[usize; Z], out: &$DstTy) { + fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &[T; Z]) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); } } } - }; -} macro_rules! select_nz { - ($Axis:expr, $SrcTy:tt, $IndTy:tt, $DstTy:tt, {$($Dims:tt),*}) => { -impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, $IndTy, $DstTy, $Axis> for Cpu { - fn select_axis(inp: &$SrcTy, indices: &$IndTy, out: &mut $DstTy) { - for m in 0..M { - Self::select_axis(&inp[m], &indices[m], &mut out[m]); + ($Axis:expr, $SubAxis:expr) => { + impl SelectAlongAxis<[T; M], [I; M], [R; M], $Axis> for Cpu + where + Self: SelectAlongAxis, + T: CountElements, + R: CountElements, + { + fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut [R; M]) { + for m in 0..M { + Self::select_axis(&inp[m], &indices[m], &mut out[m]); + } + } + fn select_add(inp: &mut [T; M], indices: &[I; M], out: &[R; M]) { + for m in 0..M { + Self::select_add(&mut inp[m], &indices[m], &out[m]); + } + } } - } - fn select_add(inp: &mut $SrcTy, indices: &$IndTy, out: &$DstTy) { - for m in 0..M { - Self::select_add(&mut inp[m], &indices[m], &out[m]); - } - } -} }; } -// 1d -select_01!(-1, [f32; M], f32, { M }); -select_0z!(-1, [f32; M], [f32; Z], {M, Z}); - -// 2d -select_01!(0, [[f32; N]; M], [f32; N], {M, N}); -select_0z!(0, [[f32; N]; M], [[f32; N]; Z], {M, N, Z}); -select_nz!(-1, [[f32; N]; M], [usize; M], [f32; M], {M, N}); -select_nz!(-1, [[f32; N]; M], [[usize; Z]; M], [[f32; Z]; M], {M, N, Z}); - -// 3d -select_01!(0, [[[f32; O]; N]; M], [[f32; O]; N], {M, N, O}); -select_0z!(0, [[[f32; O]; N]; M], [[[f32; O]; N]; Z], {M, N, O, Z}); -select_nz!(1, [[[f32; O]; N]; M], [usize; M], [[f32; O]; M], {M, N, O}); -select_nz!(1, [[[f32; O]; N]; M], [[usize; Z]; M], [[[f32; O]; Z]; M], {M, N, O, Z}); -select_nz!(-1, [[[f32; O]; N]; M], [[usize; N]; M], [[f32; N]; M], {M, N, O}); -select_nz!(-1, [[[f32; O]; N]; M], [[[usize; Z]; N]; M], [[[f32; Z]; N]; M], {M, N, O, Z}); - -// 4d -select_01!(0, [[[[f32; P]; O]; N]; M], [[[f32; P]; O]; N], {M, N, O, P}); -select_0z!(0, [[[[f32; P]; O]; N]; M], [[[[f32; P]; O]; N]; Z], {M, N, O, P, Z}); -select_nz!(1, [[[[f32; P]; O]; N]; M], [usize; M], [[[f32; P]; O]; M], {M, N, O, P}); -select_nz!(1, [[[[f32; P]; O]; N]; M], [[usize; Z]; M], [[[[f32; P]; O]; Z]; M], {M, N, O, P, Z}); -select_nz!(2, [[[[f32; P]; O]; N]; M], [[usize; N]; M], [[[f32; P]; N]; M], {M, N, O, P}); -select_nz!(2, [[[[f32; P]; O]; N]; M], [[[usize; Z]; N]; M], [[[[f32; P]; Z]; N]; M], {M, N, O, P, Z}); -select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[usize; O]; N]; M], [[[f32; O]; N]; M], {M, N, O, P}); -select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[[usize; Z]; O]; N]; M], [[[[f32; Z]; O]; N]; M], {M, N, O, P, Z}); +select_nz!(1, 0); +select_nz!(2, 1); +select_nz!(3, 2); #[cfg(test)] mod tests { @@ -112,7 +95,7 @@ mod tests { fn test_select_1d_0() { let a = [1.0, 2.0, 3.0]; let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &1, &mut b); + Cpu::select_axis(&a, &1usize, &mut b); assert_eq!(b, 2.0); } @@ -146,7 +129,7 @@ mod tests { fn test_select_2d_1() { let a = A_2D; let mut b = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 1], &mut b); + >::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -210,7 +193,7 @@ mod tests { fn test_select_3d_2() { let a = A_3D; let mut b = ZeroElements::ZEROS; - >::select_axis( + >::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], &mut b, @@ -230,7 +213,7 @@ mod tests { fn test_select_3d_2z() { let a = A_3D; let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS; - >::select_axis( + >::select_axis( &a, &[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]], &mut b, From 2297904fedd944fa7a91a3b6379d96f45117ddb4 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 8 Sep 2022 09:13:59 -0400 Subject: [PATCH 02/14] Rework devices/select. Adding Broadcasted select --- src/devices/select.rs | 106 +++++++++++++++++++++++++++++------------- 1 file changed, 74 insertions(+), 32 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index e73a2d172..32bdb9d69 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -19,61 +19,74 @@ use super::{Cpu, ForEachElement}; use crate::arrays::CountElements; -/// Select values from `T` using `Indices` and producing `R` along a single `AXIS`. -pub trait SelectAlongAxis { +pub mod modes { + pub struct Index; + pub struct Recurse; + pub struct Broadcast; +} + +use modes::*; + +pub trait DeviceSelect { + type Result; + /// Equivalent to psuedocode `out = inp[indices]` - fn select_axis(inp: &T, indices: &Indices, out: &mut R); + fn select_axis(inp: &T, indices: &I, out: &mut Self::Result); /// `inp[indices] += out` - fn select_add(inp: &mut T, indices: &Indices, out: &R); + fn select_add(inp: &mut T, indices: &I, out: &Self::Result); } -impl SelectAlongAxis<[T; M], usize, T, 0> for Cpu +impl DeviceSelect<[T; M], usize, Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - fn select_axis(inp: &[T; M], indices: &usize, out: &mut T) { + type Result = T; + + fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) { *out = inp[*indices]; } - fn select_add(inp: &mut [T; M], indices: &usize, out: &T) { + fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } } -impl SelectAlongAxis<[T; M], [usize; Z], [T; Z], 0> for Cpu +impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut [T; Z]) { + type Result = [T; Z]; + fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) { for z in 0..Z { out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &[T; Z]) { + + fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); } } } -macro_rules! select_nz { - ($Axis:expr, $SubAxis:expr) => { - impl SelectAlongAxis<[T; M], [I; M], [R; M], $Axis> for Cpu +macro_rules! nd_recurse { + ($Mode:ty, $SubMode:ty) => { + impl DeviceSelect<[T; M], [I; M], $Mode> for Cpu where - Self: SelectAlongAxis, - T: CountElements, - R: CountElements, + Self: DeviceSelect, { - fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut [R; M]) { + type Result = [>::Result; M]; + + fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) { for m in 0..M { Self::select_axis(&inp[m], &indices[m], &mut out[m]); } } - fn select_add(inp: &mut [T; M], indices: &[I; M], out: &[R; M]) { + fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) { for m in 0..M { Self::select_add(&mut inp[m], &indices[m], &out[m]); } @@ -82,9 +95,28 @@ macro_rules! select_nz { }; } -select_nz!(1, 0); -select_nz!(2, 1); -select_nz!(3, 2); +nd_recurse!(Recurse<0>, Index); +nd_recurse!(Recurse<1>, Recurse<0>); +nd_recurse!(Recurse<2>, Recurse<1>); +nd_recurse!(Recurse<3>, Recurse<2>); + +impl DeviceSelect> for Cpu +where + Self: DeviceSelect, +{ + type Result = [>::Result; M]; + + fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) { + for m in 0..M { + Self::select_axis(inp, &indices[m], &mut out[m]); + } + } + fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) { + for m in 0..M { + Self::select_add(inp, &indices[m], &out[m]); + } + } +} #[cfg(test)] mod tests { @@ -101,9 +133,9 @@ mod tests { #[test] fn test_select_1d_0z() { - let a = [1.0, 2.0, 3.0]; + let a = [1.0f32, 2.0, 3.0]; let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); + >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]); } @@ -121,7 +153,7 @@ mod tests { fn test_select_2d_0z() { let a = A_2D; let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &[0, 0, 1], &mut b); + >::select_axis(&a, &[0, 0, 1], &mut b); assert_eq!(b, [a[0], a[0], a[1]]); } @@ -129,7 +161,7 @@ mod tests { fn test_select_2d_1() { let a = A_2D; let mut b = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 1], &mut b); + >>::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -137,16 +169,26 @@ mod tests { fn test_select_2d_1z() { let a = A_2D; let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &[[0, 2], [1, 1]], &mut b); + >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]); } + #[test] + fn test_select_broadcast_2d() { + let a = [[1.0], [2.0]]; + let i = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; + let mut b = ZeroElements::ZEROS; + >>::select_axis(&a, &i, &mut b); + #[rustfmt::skip] + assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]); + } + #[test] fn test_select_add_2d() { let mut a = [[0.0; 3]; 2]; let b = [[1.0, 3.0], [5.0, 5.0]]; let i = [[0, 2], [1, 1]]; - Cpu::select_add(&mut a, &i, &b); + >>::select_add(&mut a, &i, &b); assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]); } @@ -169,7 +211,7 @@ mod tests { fn test_select_3d_0z() { let a = A_3D; let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); + >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]); } @@ -177,7 +219,7 @@ mod tests { fn test_select_3d_1() { let a = A_3D; let mut b = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 0, 1, 1], &mut b); + >>::select_axis(&a, &[0, 0, 1, 1], &mut b); assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]); } @@ -185,7 +227,7 @@ mod tests { fn test_select_3d_1z() { let a = A_3D; let mut b = ZeroElements::ZEROS; - >::select_axis(&a, &[[0], [0], [1], [1]], &mut b); + >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]); } @@ -193,7 +235,7 @@ mod tests { fn test_select_3d_2() { let a = A_3D; let mut b = ZeroElements::ZEROS; - >::select_axis( + >>::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], &mut b, @@ -213,7 +255,7 @@ mod tests { fn test_select_3d_2z() { let a = A_3D; let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS; - >::select_axis( + >>::select_axis( &a, &[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]], &mut b, From 93f01993bb16bbec685ca6e4584c23be162ddac8 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 8 Sep 2022 09:25:46 -0400 Subject: [PATCH 03/14] Making DeviceSelect take Input & Output and index is associated --- src/devices/select.rs | 76 +++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index 32bdb9d69..ef8bd172e 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -27,46 +27,46 @@ pub mod modes { use modes::*; -pub trait DeviceSelect { - type Result; +pub trait DeviceSelect { + type Indices; /// Equivalent to psuedocode `out = inp[indices]` - fn select_axis(inp: &T, indices: &I, out: &mut Self::Result); + fn select_axis(inp: &T, indices: &Self::Indices, out: &mut R); /// `inp[indices] += out` - fn select_add(inp: &mut T, indices: &I, out: &Self::Result); + fn select_add(inp: &mut T, indices: &Self::Indices, out: &R); } -impl DeviceSelect<[T; M], usize, Index> for Cpu +impl DeviceSelect<[T; M], T, Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - type Result = T; + type Indices = usize; - fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) { + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut T) { *out = inp[*indices]; } - fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &T) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } } -impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu +impl DeviceSelect<[T; M], [T; Z], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - type Result = [T; Z]; - fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) { + type Indices = [usize; Z]; + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [T; Z]) { for z in 0..Z { out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[T; Z]) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); } @@ -75,18 +75,18 @@ where macro_rules! nd_recurse { ($Mode:ty, $SubMode:ty) => { - impl DeviceSelect<[T; M], [I; M], $Mode> for Cpu + impl DeviceSelect<[T; M], [R; M], $Mode> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Result = [>::Result; M]; + type Indices = [>::Indices; M]; - fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) { + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [R; M]) { for m in 0..M { Self::select_axis(&inp[m], &indices[m], &mut out[m]); } } - fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[R; M]) { for m in 0..M { Self::select_add(&mut inp[m], &indices[m], &out[m]); } @@ -100,18 +100,18 @@ nd_recurse!(Recurse<1>, Recurse<0>); nd_recurse!(Recurse<2>, Recurse<1>); nd_recurse!(Recurse<3>, Recurse<2>); -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Result = [>::Result; M]; + type Indices = [>::Indices; M]; - fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) { + fn select_axis(inp: &T, indices: &Self::Indices, out: &mut [R; M]) { for m in 0..M { Self::select_axis(inp, &indices[m], &mut out[m]); } } - fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) { + fn select_add(inp: &mut T, indices: &Self::Indices, out: &[R; M]) { for m in 0..M { Self::select_add(inp, &indices[m], &out[m]); } @@ -125,16 +125,16 @@ mod tests { #[test] fn test_select_1d_0() { - let a = [1.0, 2.0, 3.0]; - let mut b = ZeroElements::ZEROS; - Cpu::select_axis(&a, &1usize, &mut b); + let a: [f32; 3] = [1.0, 2.0, 3.0]; + let mut b: f32 = ZeroElements::ZEROS; + Cpu::select_axis(&a, &1, &mut b); assert_eq!(b, 2.0); } #[test] fn test_select_1d_0z() { - let a = [1.0f32, 2.0, 3.0]; - let mut b = ZeroElements::ZEROS; + let a: [f32; 3] = [1.0f32, 2.0, 3.0]; + let mut b: [f32; 6] = ZeroElements::ZEROS; >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]); } @@ -144,7 +144,7 @@ mod tests { #[test] fn test_select_2d_0() { let a = A_2D; - let mut b = ZeroElements::ZEROS; + let mut b: [f32; 3] = ZeroElements::ZEROS; Cpu::select_axis(&a, &0, &mut b); assert_eq!(b, [1.0, 2.0, 3.0]); } @@ -152,7 +152,7 @@ mod tests { #[test] fn test_select_2d_0z() { let a = A_2D; - let mut b = ZeroElements::ZEROS; + let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS; >::select_axis(&a, &[0, 0, 1], &mut b); assert_eq!(b, [a[0], a[0], a[1]]); } @@ -160,7 +160,7 @@ mod tests { #[test] fn test_select_2d_1() { let a = A_2D; - let mut b = ZeroElements::ZEROS; + let mut b: [f32; 2] = ZeroElements::ZEROS; >>::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -168,7 +168,7 @@ mod tests { #[test] fn test_select_2d_1z() { let a = A_2D; - let mut b = ZeroElements::ZEROS; + let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS; >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]); } @@ -176,8 +176,8 @@ mod tests { #[test] fn test_select_broadcast_2d() { let a = [[1.0], [2.0]]; - let i = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; - let mut b = ZeroElements::ZEROS; + let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; + let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS; >>::select_axis(&a, &i, &mut b); #[rustfmt::skip] assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]); @@ -202,7 +202,7 @@ mod tests { #[test] fn test_select_3d_0() { let a = A_3D; - let mut b = ZeroElements::ZEROS; + let mut b: [[f32; 3]; 2] = ZeroElements::ZEROS; Cpu::select_axis(&a, &0, &mut b); assert_eq!(b, A_3D[0]); } @@ -210,7 +210,7 @@ mod tests { #[test] fn test_select_3d_0z() { let a = A_3D; - let mut b = ZeroElements::ZEROS; + let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS; >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]); } @@ -218,7 +218,7 @@ mod tests { #[test] fn test_select_3d_1() { let a = A_3D; - let mut b = ZeroElements::ZEROS; + let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS; >>::select_axis(&a, &[0, 0, 1, 1], &mut b); assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]); } @@ -226,7 +226,7 @@ mod tests { #[test] fn test_select_3d_1z() { let a = A_3D; - let mut b = ZeroElements::ZEROS; + let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS; >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]); } @@ -234,7 +234,7 @@ mod tests { #[test] fn test_select_3d_2() { let a = A_3D; - let mut b = ZeroElements::ZEROS; + let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS; >>::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], From 9121ece32d4720065b06bc95cc4f289beba3fd95 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 8 Sep 2022 09:42:22 -0400 Subject: [PATCH 04/14] Temp commit of select function --- src/devices/select.rs | 6 +- src/tensor_ops/select.rs | 150 +++++++++++++++++++-------------------- 2 files changed, 75 insertions(+), 81 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index ef8bd172e..6524a5c26 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -19,16 +19,16 @@ use super::{Cpu, ForEachElement}; use crate::arrays::CountElements; -pub mod modes { +pub mod select_modes { pub struct Index; pub struct Recurse; pub struct Broadcast; } -use modes::*; +use select_modes::*; pub trait DeviceSelect { - type Indices; + type Indices: Clone; /// Equivalent to psuedocode `out = inp[indices]` fn select_axis(inp: &T, indices: &Self::Indices, out: &mut R); diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index dd92cad0a..1f6c316c7 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -10,94 +10,88 @@ use crate::prelude::*; /// 2. Select multiple values from an axis, which keeps the number /// of dimensions the same. You can select the same element multiple /// number of times. -pub trait Select1 { - type Indices: Clone; - - /// Select sub elements using [Self::Indices]. - /// The same element can be selected multiple times depending - /// on [Self::Indices]. - /// - /// Selecting single value from 2d tensors: - /// ```rust - /// # use dfdx::prelude::*; - /// // select a single element from the 0th axis - /// let _: Tensor1D<5> = Tensor2D::<3, 5>::zeros().select(&0); - /// - /// // select a single element from the 1st axis - number of elements is equal - /// // to the size of the 0th axis, and the usize values can be 0..5 - /// let _: Tensor1D<3> = Tensor2D::<3, 5>::zeros().select(&[0, 2, 4]); - ///``` - /// - /// Selecting multiple values from 2d tensors: - /// ```rust - /// # use dfdx::prelude::*; - /// // select a multiple elements from the 0th axis. - /// // the number of indices is the new size of the 0th axis. - /// let _: Tensor2D<6, 5> = Tensor2D::<3, 5>::zeros().select(&[0, 1, 2, 0, 1, 2]); - /// - /// // select a multiple elements from the 1st axis. - /// // must have same number of elements as the 0th axis, and the number of indices - /// // is the new size of the 1st axis. - /// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); - /// ``` - fn select(self, indices: &Self::Indices) -> T; +/// +/// Select sub elements using [Self::Indices]. +/// The same element can be selected multiple times depending +/// on [Self::Indices]. +/// +/// Selecting single value from 2d tensors: +/// ```rust +/// # use dfdx::prelude::*; +/// // select a single element from the 0th axis +/// let _: Tensor1D<5> = Tensor2D::<3, 5>::zeros().select(&0); +/// +/// // select a single element from the 1st axis - number of elements is equal +/// // to the size of the 0th axis, and the usize values can be 0..5 +/// let _: Tensor1D<3> = Tensor2D::<3, 5>::zeros().select(&[0, 2, 4]); +///``` +/// +/// Selecting multiple values from 2d tensors: +/// ```rust +/// # use dfdx::prelude::*; +/// // select a multiple elements from the 0th axis. +/// // the number of indices is the new size of the 0th axis. +/// let _: Tensor2D<6, 5> = Tensor2D::<3, 5>::zeros().select(&[0, 1, 2, 0, 1, 2]); +/// +/// // select a multiple elements from the 1st axis. +/// // must have same number of elements as the 0th axis, and the number of indices +/// // is the new size of the 1st axis. +/// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); +/// ``` +pub fn select( + t: T, + indices: &<::Device as DeviceSelect>::Indices, +) -> R +where + T: Tensor, + R: Tensor, + ::Device: DeviceSelect, + <::Device as DeviceSelect>::Indices: 'static, +{ + let mut result: ::NoTape = TensorCreator::zeros(); + ::Device::select_axis(t.data(), indices, result.mut_data()); + + #[allow(clippy::clone_on_copy)] + let i = indices.clone(); + + move_tape_and_add_backward_op(t, result, move |mut t, result, grads| { + let (t_grad, result_grad) = grads.mut_and_ref(&t, &result); + ::Device::fill(t.mut_data(), &mut |v| *v = 0.0); + ::Device::select_add(t.mut_data(), &i, result_grad); + ::Device::add(t_grad, t.data()); + }) } -macro_rules! impl_select { - ($Axis:expr, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { -impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy { - type Indices = $IndTy; - fn select(self, indices: &Self::Indices) -> $DstTy { - let mut result: <$DstTy as Tensor>::NoTape = TensorCreator::zeros(); - Cpu::select_axis(self.data(), indices, result.mut_data()); - - #[allow(clippy::clone_on_copy)] - let i = indices.clone(); - - move_tape_and_add_backward_op(self, result, move |mut t, result, grads| { - let (t_grad, result_grad) = grads.mut_and_ref(&t, &result); - Cpu::fill(t.mut_data(), &mut |v| *v = 0.0); - Cpu::select_add(t.mut_data(), &i, result_grad); - Cpu::add(t_grad, t.data()); - }) +macro_rules! tensor_impl { + ($typename:ident, [$($Vs:tt),*]) => { +impl<$(const $Vs: usize, )* H: Tape> $typename<$($Vs, )* H> { + /// Calls [select()] on `self`. + pub fn select( + self, + indices: &<::Device as DeviceSelect<::Array, R::Array, Mode>>::Indices + ) -> R + where + Self: Tensor, + R: Tensor::Tape>, + ::Device: DeviceSelect<::Array, R::Array, Mode>, + <::Device as DeviceSelect<::Array, R::Array, Mode>>::Indices: 'static, + { + select(self, indices) } } }; } -// 1d -impl_select!(-1, Tensor1D, usize, Tensor0D, {M}); -impl_select!(-1, Tensor1D, [usize; Z], Tensor1D, {M, Z}); - -// 2d -impl_select!(0, Tensor2D, usize, Tensor1D, {M, N}); -impl_select!(0, Tensor2D, [usize; Z], Tensor2D, {M, N, Z}); -impl_select!(-1, Tensor2D, [usize; M], Tensor1D, {M, N}); -impl_select!(-1, Tensor2D, [[usize; Z]; M], Tensor2D, {M, N, Z}); - -// 3d -impl_select!(0, Tensor3D, usize, Tensor2D, {M, N, O}); -impl_select!(0, Tensor3D, [usize; Z], Tensor3D, {M, N, O, Z}); -impl_select!(1, Tensor3D, [usize; M], Tensor2D, {M, N, O}); -impl_select!(1, Tensor3D, [[usize; Z]; M], Tensor3D, {M, N, O, Z}); -impl_select!(-1, Tensor3D, [[usize; N]; M], Tensor2D, {M, N, O}); -impl_select!(-1, Tensor3D, [[[usize; Z]; N]; M], Tensor3D, {M, N, O, Z}); - -// 4d -impl_select!(0, Tensor4D, usize, Tensor3D, {M, N, O, P}); -impl_select!(0, Tensor4D, [usize; Z], Tensor4D, {M, N, O, P, Z}); -impl_select!(1, Tensor4D, [usize; M], Tensor3D, {M, N, O, P}); -impl_select!(1, Tensor4D, [[usize; Z]; M], Tensor4D, {M, N, O, P, Z}); -impl_select!(2, Tensor4D, [[usize; N]; M], Tensor3D, {M, N, O, P}); -impl_select!(2, Tensor4D, [[[usize; Z]; N]; M], Tensor4D, {M, N, O, P, Z}); -impl_select!(-1, Tensor4D, [[[usize; O]; N]; M], Tensor3D, {M, N, O, P}); -impl_select!(-1, Tensor4D, [[[[usize; Z]; O]; N]; M], Tensor4D, {M, N, O, P, Z}); +tensor_impl!(Tensor0D, []); +tensor_impl!(Tensor1D, [M]); +tensor_impl!(Tensor2D, [M, N]); +tensor_impl!(Tensor3D, [M, N, O]); +tensor_impl!(Tensor4D, [M, N, O, P]); #[cfg(test)] mod tests { - use rand::thread_rng; - use super::*; + use rand::thread_rng; #[test] fn test_valid_selects_1d() { @@ -120,7 +114,7 @@ mod tests { fn test_select_1d_less_backward() { let mut rng = thread_rng(); let t: Tensor1D<5> = TensorCreator::randn(&mut rng); - let r: Tensor1D<2, OwnedTape> = t.trace().select(&[0, 3]); + let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, select_modes::Index>(&[0, 3]); assert_eq!(r.data(), &[t.data()[0], t.data()[3]]); let g = r.mean().backward(); assert_eq!(g.ref_gradient(&t), &[0.5, 0.0, 0.0, 0.5, 0.0]); From b5828aecc4001e4988ceeb3678f959a55de2c037 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 07:33:32 -0400 Subject: [PATCH 05/14] Improving modes of select --- src/devices/select.rs | 73 ++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index 6524a5c26..c73587ab1 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -18,14 +18,11 @@ use super::{Cpu, ForEachElement}; use crate::arrays::CountElements; +use std::marker::PhantomData; -pub mod select_modes { - pub struct Index; - pub struct Recurse; - pub struct Broadcast; -} - -use select_modes::*; +pub(crate) struct Index; +pub(crate) struct Recurse(PhantomData<*const M>); +pub(crate) struct Broadcast(PhantomData<*const M>); pub trait DeviceSelect { type Indices: Clone; @@ -48,6 +45,7 @@ where fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut T) { *out = inp[*indices]; } + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &T) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } @@ -65,7 +63,6 @@ where out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[T; Z]) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); @@ -73,38 +70,30 @@ where } } -macro_rules! nd_recurse { - ($Mode:ty, $SubMode:ty) => { - impl DeviceSelect<[T; M], [R; M], $Mode> for Cpu - where - Self: DeviceSelect, - { - type Indices = [>::Indices; M]; +impl DeviceSelect<[T; M], [R; M], Recurse> for Cpu +where + Self: DeviceSelect, +{ + type Indices = [>::Indices; M]; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [R; M]) { - for m in 0..M { - Self::select_axis(&inp[m], &indices[m], &mut out[m]); - } - } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[R; M]) { - for m in 0..M { - Self::select_add(&mut inp[m], &indices[m], &out[m]); - } - } + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [R; M]) { + for m in 0..M { + Self::select_axis(&inp[m], &indices[m], &mut out[m]); } - }; -} + } -nd_recurse!(Recurse<0>, Index); -nd_recurse!(Recurse<1>, Recurse<0>); -nd_recurse!(Recurse<2>, Recurse<1>); -nd_recurse!(Recurse<3>, Recurse<2>); + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[R; M]) { + for m in 0..M { + Self::select_add(&mut inp[m], &indices[m], &out[m]); + } + } +} -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Indices = [>::Indices; M]; + type Indices = [>::Indices; M]; fn select_axis(inp: &T, indices: &Self::Indices, out: &mut [R; M]) { for m in 0..M { @@ -161,7 +150,7 @@ mod tests { fn test_select_2d_1() { let a = A_2D; let mut b: [f32; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 1], &mut b); + >>::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -169,7 +158,7 @@ mod tests { fn test_select_2d_1z() { let a = A_2D; let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); + >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]); } @@ -178,7 +167,7 @@ mod tests { let a = [[1.0], [2.0]]; let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &i, &mut b); + >>::select_axis(&a, &i, &mut b); #[rustfmt::skip] assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]); } @@ -188,7 +177,7 @@ mod tests { let mut a = [[0.0; 3]; 2]; let b = [[1.0, 3.0], [5.0, 5.0]]; let i = [[0, 2], [1, 1]]; - >>::select_add(&mut a, &i, &b); + >>::select_add(&mut a, &i, &b); assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]); } @@ -219,7 +208,7 @@ mod tests { fn test_select_3d_1() { let a = A_3D; let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 0, 1, 1], &mut b); + >>::select_axis(&a, &[0, 0, 1, 1], &mut b); assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]); } @@ -227,7 +216,7 @@ mod tests { fn test_select_3d_1z() { let a = A_3D; let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); + >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]); } @@ -235,7 +224,7 @@ mod tests { fn test_select_3d_2() { let a = A_3D; let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS; - >>::select_axis( + >>>::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], &mut b, @@ -255,7 +244,7 @@ mod tests { fn test_select_3d_2z() { let a = A_3D; let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS; - >>::select_axis( + >>>::select_axis( &a, &[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]], &mut b, From d2f811766198e3cc50338614337907bbb4640022 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 07:36:40 -0400 Subject: [PATCH 06/14] Removing R from DeviceSelect --- src/devices/select.rs | 44 ++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index c73587ab1..1c369fc99 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -24,83 +24,89 @@ pub(crate) struct Index; pub(crate) struct Recurse(PhantomData<*const M>); pub(crate) struct Broadcast(PhantomData<*const M>); -pub trait DeviceSelect { +pub trait DeviceSelect { type Indices: Clone; + type Result; /// Equivalent to psuedocode `out = inp[indices]` - fn select_axis(inp: &T, indices: &Self::Indices, out: &mut R); + fn select_axis(inp: &T, indices: &Self::Indices, out: &mut Self::Result); /// `inp[indices] += out` - fn select_add(inp: &mut T, indices: &Self::Indices, out: &R); + fn select_add(inp: &mut T, indices: &Self::Indices, out: &Self::Result); } -impl DeviceSelect<[T; M], T, Index> for Cpu +impl DeviceSelect<[T; M], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { type Indices = usize; + type Result = T; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut T) { + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { *out = inp[*indices]; } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &T) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } } -impl DeviceSelect<[T; M], [T; Z], Index> for Cpu +impl DeviceSelect<[T; M], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { type Indices = [usize; Z]; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [T; Z]) { + type Result = [T; Z]; + + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { for z in 0..Z { out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[T; Z]) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); } } } -impl DeviceSelect<[T; M], [R; M], Recurse> for Cpu +impl DeviceSelect<[T; M], Recurse> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Indices = [>::Indices; M]; + type Indices = [>::Indices; M]; + type Result = [>::Result; M]; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut [R; M]) { + fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { for m in 0..M { Self::select_axis(&inp[m], &indices[m], &mut out[m]); } } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &[R; M]) { + fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { for m in 0..M { Self::select_add(&mut inp[m], &indices[m], &out[m]); } } } -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Indices = [>::Indices; M]; + type Indices = [>::Indices; M]; + type Result = [>::Result; M]; - fn select_axis(inp: &T, indices: &Self::Indices, out: &mut [R; M]) { + fn select_axis(inp: &T, indices: &Self::Indices, out: &mut Self::Result) { for m in 0..M { Self::select_axis(inp, &indices[m], &mut out[m]); } } - fn select_add(inp: &mut T, indices: &Self::Indices, out: &[R; M]) { + fn select_add(inp: &mut T, indices: &Self::Indices, out: &Self::Result) { for m in 0..M { Self::select_add(inp, &indices[m], &out[m]); } From 27f808657ff6f01f92c26c563ddab9baf2634470 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 07:47:06 -0400 Subject: [PATCH 07/14] tensor_ops select working --- src/devices/mod.rs | 5 ++++ src/devices/select.rs | 50 ++++++++++++++++------------------------ src/tensor_ops/select.rs | 33 ++++++++++++-------------- 3 files changed, 40 insertions(+), 48 deletions(-) diff --git a/src/devices/mod.rs b/src/devices/mod.rs index 4dfd9596b..f970098a7 100644 --- a/src/devices/mod.rs +++ b/src/devices/mod.rs @@ -20,8 +20,13 @@ pub use reduce_all::*; pub use reduce_axis::*; pub use select::*; +use std::marker::PhantomData; use std::ops::*; +pub struct Index; +pub struct Recurse(PhantomData<*const M>); +pub struct Broadcast(PhantomData<*const M>); + /// The CPU device pub struct Cpu; diff --git a/src/devices/select.rs b/src/devices/select.rs index 1c369fc99..1f1badd06 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -16,97 +16,87 @@ //! Then all three arrays with have the same dimension as the 0th axis. //! Do a for loop over the 0th axis and recurse! -use super::{Cpu, ForEachElement}; +use super::{Broadcast, Cpu, ForEachElement, Index, Recurse}; use crate::arrays::CountElements; -use std::marker::PhantomData; -pub(crate) struct Index; -pub(crate) struct Recurse(PhantomData<*const M>); -pub(crate) struct Broadcast(PhantomData<*const M>); - -pub trait DeviceSelect { - type Indices: Clone; +pub trait DeviceSelect { type Result; /// Equivalent to psuedocode `out = inp[indices]` - fn select_axis(inp: &T, indices: &Self::Indices, out: &mut Self::Result); + fn select_axis(inp: &T, indices: &I, out: &mut Self::Result); /// `inp[indices] += out` - fn select_add(inp: &mut T, indices: &Self::Indices, out: &Self::Result); + fn select_add(inp: &mut T, indices: &I, out: &Self::Result); } -impl DeviceSelect<[T; M], Index> for Cpu +impl DeviceSelect<[T; M], usize, Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - type Indices = usize; type Result = T; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { + fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) { *out = inp[*indices]; } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) { Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b); } } -impl DeviceSelect<[T; M], Index> for Cpu +impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>, { - type Indices = [usize; Z]; type Result = [T; Z]; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { + fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) { for z in 0..Z { out[z] = inp[indices[z]]; } } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) { for z in 0..Z { Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b); } } } -impl DeviceSelect<[T; M], Recurse> for Cpu +impl DeviceSelect<[T; M], [I; M], Recurse> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Indices = [>::Indices; M]; - type Result = [>::Result; M]; + type Result = [>::Result; M]; - fn select_axis(inp: &[T; M], indices: &Self::Indices, out: &mut Self::Result) { + fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) { for m in 0..M { Self::select_axis(&inp[m], &indices[m], &mut out[m]); } } - fn select_add(inp: &mut [T; M], indices: &Self::Indices, out: &Self::Result) { + fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) { for m in 0..M { Self::select_add(&mut inp[m], &indices[m], &out[m]); } } } -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where - Self: DeviceSelect, + Self: DeviceSelect, { - type Indices = [>::Indices; M]; - type Result = [>::Result; M]; + type Result = [>::Result; M]; - fn select_axis(inp: &T, indices: &Self::Indices, out: &mut Self::Result) { + fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) { for m in 0..M { Self::select_axis(inp, &indices[m], &mut out[m]); } } - fn select_add(inp: &mut T, indices: &Self::Indices, out: &Self::Result) { + fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) { for m in 0..M { Self::select_add(inp, &indices[m], &out[m]); } diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index 1f6c316c7..17da6784d 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -38,15 +38,12 @@ use crate::prelude::*; /// // is the new size of the 1st axis. /// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); /// ``` -pub fn select( - t: T, - indices: &<::Device as DeviceSelect>::Indices, -) -> R +pub fn select(t: T, indices: &I) -> R where T: Tensor, + I: 'static + Clone, R: Tensor, - ::Device: DeviceSelect, - <::Device as DeviceSelect>::Indices: 'static, + ::Device: DeviceSelect, { let mut result: ::NoTape = TensorCreator::zeros(); ::Device::select_axis(t.data(), indices, result.mut_data()); @@ -66,15 +63,12 @@ macro_rules! tensor_impl { ($typename:ident, [$($Vs:tt),*]) => { impl<$(const $Vs: usize, )* H: Tape> $typename<$($Vs, )* H> { /// Calls [select()] on `self`. - pub fn select( - self, - indices: &<::Device as DeviceSelect<::Array, R::Array, Mode>>::Indices - ) -> R + pub fn select(self, indices: &I) -> R where Self: Tensor, + I: 'static + Clone, R: Tensor::Tape>, - ::Device: DeviceSelect<::Array, R::Array, Mode>, - <::Device as DeviceSelect<::Array, R::Array, Mode>>::Indices: 'static, + ::Device: DeviceSelect<::Array, I, Mode, Result = R::Array>, { select(self, indices) } @@ -96,8 +90,9 @@ mod tests { #[test] fn test_valid_selects_1d() { let _: Tensor0D = Tensor1D::<5>::zeros().select(&0); - let _: Tensor1D<3> = Tensor1D::<5>::zeros().select(&[1, 2, 3]); - let _: Tensor1D<10> = Tensor1D::<5>::zeros().select(&[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]); + let _: Tensor1D<3> = Tensor1D::<5>::zeros().select::<_, _, Index>(&[1, 2, 3]); + let _: Tensor1D<10> = + Tensor1D::<5>::zeros().select::<_, _, Index>(&[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]); } #[test] @@ -114,7 +109,7 @@ mod tests { fn test_select_1d_less_backward() { let mut rng = thread_rng(); let t: Tensor1D<5> = TensorCreator::randn(&mut rng); - let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, select_modes::Index>(&[0, 3]); + let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, _, Index>(&[0, 3]); assert_eq!(r.data(), &[t.data()[0], t.data()[3]]); let g = r.mean().backward(); assert_eq!(g.ref_gradient(&t), &[0.5, 0.0, 0.0, 0.5, 0.0]); @@ -125,7 +120,7 @@ mod tests { let mut rng = thread_rng(); let t: Tensor1D<5> = TensorCreator::randn(&mut rng); let _t = *t.data(); - let r: Tensor1D<8, OwnedTape> = t.trace().select(&[0, 1, 2, 3, 4, 2, 4, 4]); + let r: Tensor1D<8, OwnedTape> = t.trace().select::<_, _, Index>(&[0, 1, 2, 3, 4, 2, 4, 4]); assert_eq!( r.data(), &[_t[0], _t[1], _t[2], _t[3], _t[4], _t[2], _t[4], _t[4]] @@ -150,7 +145,7 @@ mod tests { #[test] fn test_select_last_2d() { let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]); - let r: Tensor1D<2, OwnedTape> = t.trace().select(&[1, 2]); + let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, _, Recurse>(&[1, 2]); assert_eq!(r.data(), &[2.0, -3.0]); let gradients = r.mean().backward(); assert_eq!( @@ -167,7 +162,9 @@ mod tests { [[-3.0, 2.0, -1.0], [-6.0, 5.0, -4.0]], [[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]], ]); - let r: Tensor2D<4, 2, OwnedTape> = t.trace().select(&[[0, 1], [2, 2], [1, 1], [0, 0]]); + let r: Tensor2D<4, 2, OwnedTape> = + t.trace() + .select::<_, _, Recurse>>(&[[0, 1], [2, 2], [1, 1], [0, 0]]); assert_eq!( r.data(), &[[1.0, 5.0], [-3.0, -6.0], [2.0, 5.0], [1.0, 4.0]] From 0643daa15efdde2d9bd32ae6439317d05f899071 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 08:17:48 -0400 Subject: [PATCH 08/14] Going back to Select1 trait for tensor_ops impls --- src/devices/mod.rs | 5 -- src/devices/select.rs | 43 ++++++++----- src/tensor_ops/select.rs | 128 ++++++++++++++++++++++----------------- 3 files changed, 100 insertions(+), 76 deletions(-) diff --git a/src/devices/mod.rs b/src/devices/mod.rs index f970098a7..4dfd9596b 100644 --- a/src/devices/mod.rs +++ b/src/devices/mod.rs @@ -20,13 +20,8 @@ pub use reduce_all::*; pub use reduce_axis::*; pub use select::*; -use std::marker::PhantomData; use std::ops::*; -pub struct Index; -pub struct Recurse(PhantomData<*const M>); -pub struct Broadcast(PhantomData<*const M>); - /// The CPU device pub struct Cpu; diff --git a/src/devices/select.rs b/src/devices/select.rs index 1f1badd06..d0446e624 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -16,8 +16,19 @@ //! Then all three arrays with have the same dimension as the 0th axis. //! Do a for loop over the 0th axis and recurse! -use super::{Broadcast, Cpu, ForEachElement, Index, Recurse}; +use super::{Cpu, ForEachElement}; use crate::arrays::CountElements; +use std::marker::PhantomData; + +pub(crate) struct Idx; +pub(crate) struct Rec(PhantomData<*const M>); +pub(crate) struct Bcst(PhantomData<*const M>); + +pub(crate) type SelectAx0 = Idx; +pub(crate) type SelectAx1 = Rec; +pub(crate) type SelectAx2 = Rec; +pub(crate) type SelectAx3 = Rec; +pub(crate) type BSelectAx0 = Bcst; pub trait DeviceSelect { type Result; @@ -29,7 +40,7 @@ pub trait DeviceSelect { fn select_add(inp: &mut T, indices: &I, out: &Self::Result); } -impl DeviceSelect<[T; M], usize, Index> for Cpu +impl DeviceSelect<[T; M], usize, Idx> for Cpu where Self: ForEachElement, T: Copy + CountElements, @@ -46,7 +57,7 @@ where } } -impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu +impl DeviceSelect<[T; M], [usize; Z], Idx> for Cpu where Self: ForEachElement, T: Copy + CountElements, @@ -66,7 +77,7 @@ where } } -impl DeviceSelect<[T; M], [I; M], Recurse> for Cpu +impl DeviceSelect<[T; M], [I; M], Rec> for Cpu where Self: DeviceSelect, { @@ -85,7 +96,7 @@ where } } -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where Self: DeviceSelect, { @@ -120,7 +131,7 @@ mod tests { fn test_select_1d_0z() { let a: [f32; 3] = [1.0f32, 2.0, 3.0]; let mut b: [f32; 6] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); + >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]); } @@ -138,7 +149,7 @@ mod tests { fn test_select_2d_0z() { let a = A_2D; let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 0, 1], &mut b); + >::select_axis(&a, &[0, 0, 1], &mut b); assert_eq!(b, [a[0], a[0], a[1]]); } @@ -146,7 +157,7 @@ mod tests { fn test_select_2d_1() { let a = A_2D; let mut b: [f32; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 1], &mut b); + >>::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -154,7 +165,7 @@ mod tests { fn test_select_2d_1z() { let a = A_2D; let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); + >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]); } @@ -163,7 +174,7 @@ mod tests { let a = [[1.0], [2.0]]; let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &i, &mut b); + >>::select_axis(&a, &i, &mut b); #[rustfmt::skip] assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]); } @@ -173,7 +184,7 @@ mod tests { let mut a = [[0.0; 3]; 2]; let b = [[1.0, 3.0], [5.0, 5.0]]; let i = [[0, 2], [1, 1]]; - >>::select_add(&mut a, &i, &b); + >>::select_add(&mut a, &i, &b); assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]); } @@ -196,7 +207,7 @@ mod tests { fn test_select_3d_0z() { let a = A_3D; let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); + >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]); } @@ -204,7 +215,7 @@ mod tests { fn test_select_3d_1() { let a = A_3D; let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 0, 1, 1], &mut b); + >>::select_axis(&a, &[0, 0, 1, 1], &mut b); assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]); } @@ -212,7 +223,7 @@ mod tests { fn test_select_3d_1z() { let a = A_3D; let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); + >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]); } @@ -220,7 +231,7 @@ mod tests { fn test_select_3d_2() { let a = A_3D; let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS; - >>>::select_axis( + >>>::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], &mut b, @@ -240,7 +251,7 @@ mod tests { fn test_select_3d_2z() { let a = A_3D; let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS; - >>>::select_axis( + >>>::select_axis( &a, &[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]], &mut b, diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index 17da6784d..8f1cafe08 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -10,35 +10,40 @@ use crate::prelude::*; /// 2. Select multiple values from an axis, which keeps the number /// of dimensions the same. You can select the same element multiple /// number of times. -/// -/// Select sub elements using [Self::Indices]. -/// The same element can be selected multiple times depending -/// on [Self::Indices]. -/// -/// Selecting single value from 2d tensors: -/// ```rust -/// # use dfdx::prelude::*; -/// // select a single element from the 0th axis -/// let _: Tensor1D<5> = Tensor2D::<3, 5>::zeros().select(&0); -/// -/// // select a single element from the 1st axis - number of elements is equal -/// // to the size of the 0th axis, and the usize values can be 0..5 -/// let _: Tensor1D<3> = Tensor2D::<3, 5>::zeros().select(&[0, 2, 4]); -///``` -/// -/// Selecting multiple values from 2d tensors: -/// ```rust -/// # use dfdx::prelude::*; -/// // select a multiple elements from the 0th axis. -/// // the number of indices is the new size of the 0th axis. -/// let _: Tensor2D<6, 5> = Tensor2D::<3, 5>::zeros().select(&[0, 1, 2, 0, 1, 2]); -/// -/// // select a multiple elements from the 1st axis. -/// // must have same number of elements as the 0th axis, and the number of indices -/// // is the new size of the 1st axis. -/// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); -/// ``` -pub fn select(t: T, indices: &I) -> R +pub trait Select1 { + type Indices: Clone; + + /// Select sub elements using [Self::Indices]. + /// The same element can be selected multiple times depending + /// on [Self::Indices]. + /// + /// Selecting single value from 2d tensors: + /// ```rust + /// # use dfdx::prelude::*; + /// // select a single element from the 0th axis + /// let _: Tensor1D<5> = Tensor2D::<3, 5>::zeros().select(&0); + /// + /// // select a single element from the 1st axis - number of elements is equal + /// // to the size of the 0th axis, and the usize values can be 0..5 + /// let _: Tensor1D<3> = Tensor2D::<3, 5>::zeros().select(&[0, 2, 4]); + ///``` + /// + /// Selecting multiple values from 2d tensors: + /// ```rust + /// # use dfdx::prelude::*; + /// // select a multiple elements from the 0th axis. + /// // the number of indices is the new size of the 0th axis. + /// let _: Tensor2D<6, 5> = Tensor2D::<3, 5>::zeros().select(&[0, 1, 2, 0, 1, 2]); + /// + /// // select a multiple elements from the 1st axis. + /// // must have same number of elements as the 0th axis, and the number of indices + /// // is the new size of the 1st axis. + /// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]); + /// ``` + fn select(self, indices: &Self::Indices) -> T; +} + +fn select(t: T, indices: &I) -> R where T: Tensor, I: 'static + Clone, @@ -59,28 +64,44 @@ where }) } -macro_rules! tensor_impl { - ($typename:ident, [$($Vs:tt),*]) => { -impl<$(const $Vs: usize, )* H: Tape> $typename<$($Vs, )* H> { - /// Calls [select()] on `self`. - pub fn select(self, indices: &I) -> R - where - Self: Tensor, - I: 'static + Clone, - R: Tensor::Tape>, - ::Device: DeviceSelect<::Array, I, Mode, Result = R::Array>, - { - select(self, indices) +macro_rules! impl_select { + ($Axis:expr, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { +impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy { + type Indices = $IndTy; + fn select(self, indices: &Self::Indices) -> $DstTy { + select::<_, _, _, $Mode>(self, indices) } } }; } -tensor_impl!(Tensor0D, []); -tensor_impl!(Tensor1D, [M]); -tensor_impl!(Tensor2D, [M, N]); -tensor_impl!(Tensor3D, [M, N, O]); -tensor_impl!(Tensor4D, [M, N, O, P]); +// 1d +impl_select!(-1, SelectAx0, Tensor1D, usize, Tensor0D, {M}); +impl_select!(-1, SelectAx0, Tensor1D, [usize; Z], Tensor1D, {M, Z}); + +// 2d +impl_select!(0, SelectAx0, Tensor2D, usize, Tensor1D, {M, N}); +impl_select!(0, SelectAx0, Tensor2D, [usize; Z], Tensor2D, {M, N, Z}); +impl_select!(-1, SelectAx1, Tensor2D, [usize; M], Tensor1D, {M, N}); +impl_select!(-1, SelectAx1, Tensor2D, [[usize; Z]; M], Tensor2D, {M, N, Z}); + +// 3d +impl_select!(0, SelectAx0, Tensor3D, usize, Tensor2D, {M, N, O}); +impl_select!(0, SelectAx0, Tensor3D, [usize; Z], Tensor3D, {M, N, O, Z}); +impl_select!(1, SelectAx1, Tensor3D, [usize; M], Tensor2D, {M, N, O}); +impl_select!(1, SelectAx1, Tensor3D, [[usize; Z]; M], Tensor3D, {M, N, O, Z}); +impl_select!(-1, SelectAx2, Tensor3D, [[usize; N]; M], Tensor2D, {M, N, O}); +impl_select!(-1, SelectAx2, Tensor3D, [[[usize; Z]; N]; M], Tensor3D, {M, N, O, Z}); + +// 4d +impl_select!(0, SelectAx0, Tensor4D, usize, Tensor3D, {M, N, O, P}); +impl_select!(0, SelectAx0, Tensor4D, [usize; Z], Tensor4D, {M, N, O, P, Z}); +impl_select!(1, SelectAx1, Tensor4D, [usize; M], Tensor3D, {M, N, O, P}); +impl_select!(1, SelectAx1, Tensor4D, [[usize; Z]; M], Tensor4D, {M, N, O, P, Z}); +impl_select!(2, SelectAx2, Tensor4D, [[usize; N]; M], Tensor3D, {M, N, O, P}); +impl_select!(2, SelectAx2, Tensor4D, [[[usize; Z]; N]; M], Tensor4D, {M, N, O, P, Z}); +impl_select!(-1, SelectAx3, Tensor4D, [[[usize; O]; N]; M], Tensor3D, {M, N, O, P}); +impl_select!(-1, SelectAx3, Tensor4D, [[[[usize; Z]; O]; N]; M], Tensor4D, {M, N, O, P, Z}); #[cfg(test)] mod tests { @@ -90,9 +111,8 @@ mod tests { #[test] fn test_valid_selects_1d() { let _: Tensor0D = Tensor1D::<5>::zeros().select(&0); - let _: Tensor1D<3> = Tensor1D::<5>::zeros().select::<_, _, Index>(&[1, 2, 3]); - let _: Tensor1D<10> = - Tensor1D::<5>::zeros().select::<_, _, Index>(&[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]); + let _: Tensor1D<3> = Tensor1D::<5>::zeros().select(&[1, 2, 3]); + let _: Tensor1D<10> = Tensor1D::<5>::zeros().select(&[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]); } #[test] @@ -109,7 +129,7 @@ mod tests { fn test_select_1d_less_backward() { let mut rng = thread_rng(); let t: Tensor1D<5> = TensorCreator::randn(&mut rng); - let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, _, Index>(&[0, 3]); + let r: Tensor1D<2, OwnedTape> = t.trace().select(&[0, 3]); assert_eq!(r.data(), &[t.data()[0], t.data()[3]]); let g = r.mean().backward(); assert_eq!(g.ref_gradient(&t), &[0.5, 0.0, 0.0, 0.5, 0.0]); @@ -120,7 +140,7 @@ mod tests { let mut rng = thread_rng(); let t: Tensor1D<5> = TensorCreator::randn(&mut rng); let _t = *t.data(); - let r: Tensor1D<8, OwnedTape> = t.trace().select::<_, _, Index>(&[0, 1, 2, 3, 4, 2, 4, 4]); + let r: Tensor1D<8, OwnedTape> = t.trace().select(&[0, 1, 2, 3, 4, 2, 4, 4]); assert_eq!( r.data(), &[_t[0], _t[1], _t[2], _t[3], _t[4], _t[2], _t[4], _t[4]] @@ -145,7 +165,7 @@ mod tests { #[test] fn test_select_last_2d() { let t: Tensor2D<2, 3> = Tensor2D::new([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]); - let r: Tensor1D<2, OwnedTape> = t.trace().select::<_, _, Recurse>(&[1, 2]); + let r: Tensor1D<2, OwnedTape> = t.trace().select(&[1, 2]); assert_eq!(r.data(), &[2.0, -3.0]); let gradients = r.mean().backward(); assert_eq!( @@ -162,9 +182,7 @@ mod tests { [[-3.0, 2.0, -1.0], [-6.0, 5.0, -4.0]], [[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]], ]); - let r: Tensor2D<4, 2, OwnedTape> = - t.trace() - .select::<_, _, Recurse>>(&[[0, 1], [2, 2], [1, 1], [0, 0]]); + let r: Tensor2D<4, 2, OwnedTape> = t.trace().select(&[[0, 1], [2, 2], [1, 1], [0, 0]]); assert_eq!( r.data(), &[[1.0, 5.0], [-3.0, -6.0], [2.0, 5.0], [1.0, 4.0]] From de69068f1ae83c35d960405938cd86ca00fd01b5 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 08:34:31 -0400 Subject: [PATCH 09/14] Adding SelectBatchAx0 --- src/tensor_ops/select.rs | 49 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index 8f1cafe08..c300ad3e0 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -43,7 +43,7 @@ pub trait Select1 { fn select(self, indices: &Self::Indices) -> T; } -fn select(t: T, indices: &I) -> R +pub(crate) fn select(t: T, indices: &I) -> R where T: Tensor, I: 'static + Clone, @@ -103,6 +103,46 @@ impl_select!(2, SelectAx2, Tensor4D, [[[usize; Z]; N]; M], Tensor impl_select!(-1, SelectAx3, Tensor4D, [[[usize; O]; N]; M], Tensor3D, {M, N, O, P}); impl_select!(-1, SelectAx3, Tensor4D, [[[[usize; Z]; O]; N]; M], Tensor4D, {M, N, O, P, Z}); +/// Select batched values from axis 0, resulting in `T`. Equivalent +/// to `torch.select` and `torch.gather` from pytorch. +pub trait SelectBatchAx0 { + type Indices; + + /// Select sub elements using [Self::Indices]. + /// The same element can be selected multiple times depending + /// on [Self::Indices]. + /// + /// This results in a tensor 1 dimension larger than self. + /// + /// Selecting batch of values from a 1d tensor: + /// ```rust + /// # use dfdx::prelude::*; + /// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]); + ///``` + /// + /// Selecting batch of values from a 2d tensor: + /// ```rust + /// # use dfdx::prelude::*; + /// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[[0], [1]], [[2], [3]]]); + ///``` + fn select_batch(self, indices: &Self::Indices) -> T; +} + +macro_rules! impl_select_batch { + ($SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { +impl<$(const $Dims: usize, )* H: Tape> SelectBatchAx0<$DstTy> for $SrcTy { + type Indices = $IndTy; + fn select_batch(self, indices: &Self::Indices) -> $DstTy { + select::<_, _, _, BSelectAx0>(self, indices) + } +} + }; +} + +impl_select_batch!(Tensor1D, [[usize; Z]; B], Tensor2D, {M, B, Z}); +impl_select_batch!(Tensor2D, [[usize; Z]; B], Tensor3D, {M, N, B, Z}); +impl_select_batch!(Tensor3D, [[usize; Z]; B], Tensor4D, {M, N, O, B, Z}); + #[cfg(test)] mod tests { use super::*; @@ -115,6 +155,13 @@ mod tests { let _: Tensor1D<10> = Tensor1D::<5>::zeros().select(&[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]); } + #[test] + fn test_valid_select_batches() { + let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]); + let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]); + let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [1]]); + } + #[test] fn test_select_1d_backward() { let mut rng = thread_rng(); From b5f902286d1b0ff4b4f67c1b5848885ab79cf381 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Fri, 9 Sep 2022 08:44:29 -0400 Subject: [PATCH 10/14] Adding test for select_batch --- src/tensor_ops/select.rs | 57 +++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index c300ad3e0..2a0f9c5bd 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -43,27 +43,6 @@ pub trait Select1 { fn select(self, indices: &Self::Indices) -> T; } -pub(crate) fn select(t: T, indices: &I) -> R -where - T: Tensor, - I: 'static + Clone, - R: Tensor, - ::Device: DeviceSelect, -{ - let mut result: ::NoTape = TensorCreator::zeros(); - ::Device::select_axis(t.data(), indices, result.mut_data()); - - #[allow(clippy::clone_on_copy)] - let i = indices.clone(); - - move_tape_and_add_backward_op(t, result, move |mut t, result, grads| { - let (t_grad, result_grad) = grads.mut_and_ref(&t, &result); - ::Device::fill(t.mut_data(), &mut |v| *v = 0.0); - ::Device::select_add(t.mut_data(), &i, result_grad); - ::Device::add(t_grad, t.data()); - }) -} - macro_rules! impl_select { ($Axis:expr, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => { impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy { @@ -143,9 +122,31 @@ impl_select_batch!(Tensor1D, [[usize; Z]; B], Tensor2D, {M, B, Z} impl_select_batch!(Tensor2D, [[usize; Z]; B], Tensor3D, {M, N, B, Z}); impl_select_batch!(Tensor3D, [[usize; Z]; B], Tensor4D, {M, N, O, B, Z}); +pub(crate) fn select(t: T, indices: &I) -> R +where + T: Tensor, + I: 'static + Clone, + R: Tensor, + ::Device: DeviceSelect, +{ + let mut result: ::NoTape = TensorCreator::zeros(); + ::Device::select_axis(t.data(), indices, result.mut_data()); + + #[allow(clippy::clone_on_copy)] + let i = indices.clone(); + + move_tape_and_add_backward_op(t, result, move |mut t, result, grads| { + let (t_grad, result_grad) = grads.mut_and_ref(&t, &result); + ::Device::fill(t.mut_data(), &mut |v| *v = 0.0); + ::Device::select_add(t.mut_data(), &i, result_grad); + ::Device::add(t_grad, t.data()); + }) +} + #[cfg(test)] mod tests { use super::*; + use crate::tests::assert_close; use rand::thread_rng; #[test] @@ -245,4 +246,18 @@ mod tests { ] ); } + + #[test] + fn test_select_batch_backwards() { + let mut rng = thread_rng(); + let t: Tensor2D<4, 5> = TensorCreator::randn(&mut rng); + let r: Tensor3D<2, 3, 5, _> = t.trace().select_batch(&[[2, 0, 3], [0, 0, 3]]); + let r0: Tensor2D<3, 5> = t.clone().select(&[2, 0, 3]); + let r1: Tensor2D<3, 5> = t.clone().select(&[0, 0, 3]); + assert_close(&r.data()[0], r0.data()); + assert_close(&r.data()[1], r1.data()); + + let g = r.sum().backward(); + assert_eq!(g.ref_gradient(&t), &[[3.; 5], [0.; 5], [1.; 5], [2.; 5]]); + } } From 47693819b86962bfaf1c344158ebcd7a49a1ba4a Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 12 Sep 2022 08:24:12 -0400 Subject: [PATCH 11/14] Fixing select test --- src/tensor_ops/select.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index 2a0f9c5bd..a9601a999 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -160,7 +160,7 @@ mod tests { fn test_valid_select_batches() { let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]); let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]); - let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [1]]); + let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [0]]); } #[test] From 81a77203dcfa9a765fe4788901dcda81130d688e Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 12 Sep 2022 08:30:24 -0400 Subject: [PATCH 12/14] Fixing doctest for select --- src/tensor_ops/select.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index a9601a999..7aad877d3 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -102,7 +102,7 @@ pub trait SelectBatchAx0 { /// Selecting batch of values from a 2d tensor: /// ```rust /// # use dfdx::prelude::*; - /// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[[0], [1]], [[2], [3]]]); + /// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]); ///``` fn select_batch(self, indices: &Self::Indices) -> T; } From 7b2c6890389a4bd4abf9d3e96ac9b32fc2e762a2 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 12 Sep 2022 08:44:16 -0400 Subject: [PATCH 13/14] Filling out docstrings for select --- src/devices/select.rs | 73 +++++++++++++++++++++++++--------------- src/tensor_ops/select.rs | 2 +- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/src/devices/select.rs b/src/devices/select.rs index d0446e624..0e3f30871 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -1,34 +1,51 @@ //! Implementations of selecting either 1 or Z elements from an axis of an nd array. //! //! # Implementation Details -//! There are three cases to handle: +//! There are four cases to handle: //! -//! ## Selecting 1 element from the 0th axis +//! ## Selecting 1 element from the 0th axis [select_modes::Index] //! //! Just index into input using the single index and assign to output. //! -//! ## Selecting Z elements from the 0th axis +//! ## Selecting Z elements from the 0th axis [select_modes::Index] //! //! Just index into input for each index and assing to `output[z]` //! -//! ## Selecting either 1 or Z elements from a non-zero axis +//! ## Selecting either 1 or Z elements from a non-zero axis [select_modes::Recurse] //! //! Then all three arrays with have the same dimension as the 0th axis. //! Do a for loop over the 0th axis and recurse! +//! +//! ## Broadcasted select [select_modes::Broadcast] +//! +//! In this case only the indices & output are indexed. The input is broadcasted by +//! not indexing into it. use super::{Cpu, ForEachElement}; use crate::arrays::CountElements; -use std::marker::PhantomData; -pub(crate) struct Idx; -pub(crate) struct Rec(PhantomData<*const M>); -pub(crate) struct Bcst(PhantomData<*const M>); +/// Used to disambiguate trait implementations. Callees +/// must specify what kind of selection is occurring. +pub(crate) mod select_modes { + use std::marker::PhantomData; + + /// Select the current axis. + pub struct Index; + + /// Recurse the current axis. + pub struct Recurse(PhantomData<*const M>); + + /// Broadcast the current axis of input and recurse the indices. + pub struct Broadcast(PhantomData<*const M>); +} + +use select_modes::{Broadcast, Index, Recurse}; -pub(crate) type SelectAx0 = Idx; -pub(crate) type SelectAx1 = Rec; -pub(crate) type SelectAx2 = Rec; -pub(crate) type SelectAx3 = Rec; -pub(crate) type BSelectAx0 = Bcst; +pub(crate) type SelectAx0 = select_modes::Index; +pub(crate) type SelectAx1 = select_modes::Recurse; +pub(crate) type SelectAx2 = select_modes::Recurse; +pub(crate) type SelectAx3 = select_modes::Recurse; +pub(crate) type BSelectAx1 = select_modes::Broadcast; pub trait DeviceSelect { type Result; @@ -40,7 +57,7 @@ pub trait DeviceSelect { fn select_add(inp: &mut T, indices: &I, out: &Self::Result); } -impl DeviceSelect<[T; M], usize, Idx> for Cpu +impl DeviceSelect<[T; M], usize, Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, @@ -57,7 +74,7 @@ where } } -impl DeviceSelect<[T; M], [usize; Z], Idx> for Cpu +impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu where Self: ForEachElement, T: Copy + CountElements, @@ -77,7 +94,7 @@ where } } -impl DeviceSelect<[T; M], [I; M], Rec> for Cpu +impl DeviceSelect<[T; M], [I; M], Recurse> for Cpu where Self: DeviceSelect, { @@ -96,7 +113,7 @@ where } } -impl DeviceSelect> for Cpu +impl DeviceSelect> for Cpu where Self: DeviceSelect, { @@ -131,7 +148,7 @@ mod tests { fn test_select_1d_0z() { let a: [f32; 3] = [1.0f32, 2.0, 3.0]; let mut b: [f32; 6] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); + >::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b); assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]); } @@ -149,7 +166,7 @@ mod tests { fn test_select_2d_0z() { let a = A_2D; let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 0, 1], &mut b); + >::select_axis(&a, &[0, 0, 1], &mut b); assert_eq!(b, [a[0], a[0], a[1]]); } @@ -157,7 +174,7 @@ mod tests { fn test_select_2d_1() { let a = A_2D; let mut b: [f32; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 1], &mut b); + >>::select_axis(&a, &[0, 1], &mut b); assert_eq!(b, [1.0, 5.0]); } @@ -165,7 +182,7 @@ mod tests { fn test_select_2d_1z() { let a = A_2D; let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); + >>::select_axis(&a, &[[0, 2], [1, 1]], &mut b); assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]); } @@ -174,7 +191,7 @@ mod tests { let a = [[1.0], [2.0]]; let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]]; let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &i, &mut b); + >>::select_axis(&a, &i, &mut b); #[rustfmt::skip] assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]); } @@ -184,7 +201,7 @@ mod tests { let mut a = [[0.0; 3]; 2]; let b = [[1.0, 3.0], [5.0, 5.0]]; let i = [[0, 2], [1, 1]]; - >>::select_add(&mut a, &i, &b); + >>::select_add(&mut a, &i, &b); assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]); } @@ -207,7 +224,7 @@ mod tests { fn test_select_3d_0z() { let a = A_3D; let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS; - >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); + >::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b); assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]); } @@ -215,7 +232,7 @@ mod tests { fn test_select_3d_1() { let a = A_3D; let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[0, 0, 1, 1], &mut b); + >>::select_axis(&a, &[0, 0, 1, 1], &mut b); assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]); } @@ -223,7 +240,7 @@ mod tests { fn test_select_3d_1z() { let a = A_3D; let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS; - >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); + >>::select_axis(&a, &[[0], [0], [1], [1]], &mut b); assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]); } @@ -231,7 +248,7 @@ mod tests { fn test_select_3d_2() { let a = A_3D; let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS; - >>>::select_axis( + >>>::select_axis( &a, &[[1, 0], [0, 1], [0, 0], [1, 1]], &mut b, @@ -251,7 +268,7 @@ mod tests { fn test_select_3d_2z() { let a = A_3D; let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS; - >>>::select_axis( + >>>::select_axis( &a, &[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]], &mut b, diff --git a/src/tensor_ops/select.rs b/src/tensor_ops/select.rs index 7aad877d3..2ebd536a0 100644 --- a/src/tensor_ops/select.rs +++ b/src/tensor_ops/select.rs @@ -112,7 +112,7 @@ macro_rules! impl_select_batch { impl<$(const $Dims: usize, )* H: Tape> SelectBatchAx0<$DstTy> for $SrcTy { type Indices = $IndTy; fn select_batch(self, indices: &Self::Indices) -> $DstTy { - select::<_, _, _, BSelectAx0>(self, indices) + select::<_, _, _, BSelectAx1>(self, indices) } } }; From f87fa5cc0857d12eba865aafbca1481f78af79ad Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 12 Sep 2022 08:46:29 -0400 Subject: [PATCH 14/14] Adding comments --- src/devices/select.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/devices/select.rs b/src/devices/select.rs index 0e3f30871..b397ce5a4 100644 --- a/src/devices/select.rs +++ b/src/devices/select.rs @@ -47,6 +47,7 @@ pub(crate) type SelectAx2 = select_modes::Recurse; pub(crate) type SelectAx3 = select_modes::Recurse; pub(crate) type BSelectAx1 = select_modes::Broadcast; +/// Select values from `T` using indices `I`. `Mode` is used to disambiguate the impl. pub trait DeviceSelect { type Result; @@ -57,6 +58,7 @@ pub trait DeviceSelect { fn select_add(inp: &mut T, indices: &I, out: &Self::Result); } +// Select 1 element from 0th axis. impl DeviceSelect<[T; M], usize, Index> for Cpu where Self: ForEachElement, @@ -74,6 +76,7 @@ where } } +// Select Z elements from 0th axis. impl DeviceSelect<[T; M], [usize; Z], Index> for Cpu where Self: ForEachElement, @@ -94,6 +97,7 @@ where } } +// Select elements from non-zero axis impl DeviceSelect<[T; M], [I; M], Recurse> for Cpu where Self: DeviceSelect, @@ -113,6 +117,7 @@ where } } +// Broadcast select elements from non-zero axis. impl DeviceSelect> for Cpu where Self: DeviceSelect,