Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Batched Select for devices and tensor_ops #182

Merged
merged 14 commits into from
Sep 18, 2022
201 changes: 122 additions & 79 deletions src/devices/select.rs
Original file line number Diff line number Diff line change
@@ -1,107 +1,140 @@
//! Implementations of selecting either 1 or Z elements from an axis of an nd array.
//!
//! # Implementation Details
//! There are three cases to handle:
//! There are four cases to handle:
//!
//! ## Selecting 1 element from the 0th axis
//! ## Selecting 1 element from the 0th axis [select_modes::Index]
//!
//! Just index into input using the single index and assign to output.
//!
//! ## Selecting Z elements from the 0th axis
//! ## Selecting Z elements from the 0th axis [select_modes::Index]
//!
//! Just index into input for each index and assing to `output[z]`
//!
//! ## Selecting either 1 or Z elements from a non-zero axis
//! ## Selecting either 1 or Z elements from a non-zero axis [select_modes::Recurse]
//!
//! Then all three arrays with have the same dimension as the 0th axis.
//! Do a for loop over the 0th axis and recurse!
//!
//! ## Broadcasted select [select_modes::Broadcast]
//!
//! In this case only the indices & output are indexed. The input is broadcasted by
//! not indexing into it.

use super::{Cpu, ForEachElement};
use crate::arrays::CountElements;

/// Select values from `T` using `Indices` and producing `R` along a single `AXIS`.
pub trait SelectAlongAxis<T: CountElements, Indices, R: CountElements, const AXIS: isize> {
/// Used to disambiguate trait implementations. Callees
/// must specify what kind of selection is occurring.
pub(crate) mod select_modes {
use std::marker::PhantomData;

/// Select the current axis.
pub struct Index;

/// Recurse the current axis.
pub struct Recurse<M>(PhantomData<*const M>);

/// Broadcast the current axis of input and recurse the indices.
pub struct Broadcast<M>(PhantomData<*const M>);
}

use select_modes::{Broadcast, Index, Recurse};

pub(crate) type SelectAx0 = select_modes::Index;
pub(crate) type SelectAx1 = select_modes::Recurse<SelectAx0>;
pub(crate) type SelectAx2 = select_modes::Recurse<SelectAx1>;
pub(crate) type SelectAx3 = select_modes::Recurse<SelectAx2>;
pub(crate) type BSelectAx1 = select_modes::Broadcast<SelectAx0>;

/// Select values from `T` using indices `I`. `Mode` is used to disambiguate the impl.
pub trait DeviceSelect<T, I, Mode> {
type Result;

/// Equivalent to psuedocode `out = inp[indices]`
fn select_axis(inp: &T, indices: &Indices, out: &mut R);
fn select_axis(inp: &T, indices: &I, out: &mut Self::Result);

/// `inp[indices] += out`
fn select_add(inp: &mut T, indices: &Indices, out: &R);
fn select_add(inp: &mut T, indices: &I, out: &Self::Result);
}

macro_rules! select_01 {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, usize, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &usize, out: &mut $DstTy) {
// Select 1 element from 0th axis.
impl<T, const M: usize> DeviceSelect<[T; M], usize, Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = T;

fn select_axis(inp: &[T; M], indices: &usize, out: &mut Self::Result) {
*out = inp[*indices];
}
fn select_add(inp: &mut $SrcTy, indices: &usize, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &usize, out: &Self::Result) {
Self::foreach_mr(&mut inp[*indices], out, &mut |a, b| *a += b);
}
}
};
}

macro_rules! select_0z {
($Axis:expr, $SrcTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, [usize; Z], $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &[usize; Z], out: &mut $DstTy) {
// Select Z elements from 0th axis.
impl<T, const M: usize, const Z: usize> DeviceSelect<[T; M], [usize; Z], Index> for Cpu
where
Self: ForEachElement<T>,
T: Copy + CountElements,
T::Dtype: for<'a> std::ops::AddAssign<&'a T::Dtype>,
{
type Result = [T; Z];

fn select_axis(inp: &[T; M], indices: &[usize; Z], out: &mut Self::Result) {
for z in 0..Z {
out[z] = inp[indices[z]];
}
}
fn select_add(inp: &mut $SrcTy, indices: &[usize; Z], out: &$DstTy) {
fn select_add(inp: &mut [T; M], indices: &[usize; Z], out: &Self::Result) {
for z in 0..Z {
Self::foreach_mr(&mut inp[indices[z]], &out[z], &mut |a, b| *a += b);
}
}
}
};
}

macro_rules! select_nz {
($Axis:expr, $SrcTy:tt, $IndTy:tt, $DstTy:tt, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize),*> SelectAlongAxis<$SrcTy, $IndTy, $DstTy, $Axis> for Cpu {
fn select_axis(inp: &$SrcTy, indices: &$IndTy, out: &mut $DstTy) {
// Select elements from non-zero axis
impl<T, I, const M: usize, SubMode> DeviceSelect<[T; M], [I; M], Recurse<SubMode>> for Cpu
where
Self: DeviceSelect<T, I, SubMode>,
{
type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

fn select_axis(inp: &[T; M], indices: &[I; M], out: &mut Self::Result) {
for m in 0..M {
Self::select_axis(&inp[m], &indices[m], &mut out[m]);
}
}
fn select_add(inp: &mut $SrcTy, indices: &$IndTy, out: &$DstTy) {

fn select_add(inp: &mut [T; M], indices: &[I; M], out: &Self::Result) {
for m in 0..M {
Self::select_add(&mut inp[m], &indices[m], &out[m]);
}
}
}
};
}

// 1d
select_01!(-1, [f32; M], f32, { M });
select_0z!(-1, [f32; M], [f32; Z], {M, Z});

// 2d
select_01!(0, [[f32; N]; M], [f32; N], {M, N});
select_0z!(0, [[f32; N]; M], [[f32; N]; Z], {M, N, Z});
select_nz!(-1, [[f32; N]; M], [usize; M], [f32; M], {M, N});
select_nz!(-1, [[f32; N]; M], [[usize; Z]; M], [[f32; Z]; M], {M, N, Z});

// 3d
select_01!(0, [[[f32; O]; N]; M], [[f32; O]; N], {M, N, O});
select_0z!(0, [[[f32; O]; N]; M], [[[f32; O]; N]; Z], {M, N, O, Z});
select_nz!(1, [[[f32; O]; N]; M], [usize; M], [[f32; O]; M], {M, N, O});
select_nz!(1, [[[f32; O]; N]; M], [[usize; Z]; M], [[[f32; O]; Z]; M], {M, N, O, Z});
select_nz!(-1, [[[f32; O]; N]; M], [[usize; N]; M], [[f32; N]; M], {M, N, O});
select_nz!(-1, [[[f32; O]; N]; M], [[[usize; Z]; N]; M], [[[f32; Z]; N]; M], {M, N, O, Z});

// 4d
select_01!(0, [[[[f32; P]; O]; N]; M], [[[f32; P]; O]; N], {M, N, O, P});
select_0z!(0, [[[[f32; P]; O]; N]; M], [[[[f32; P]; O]; N]; Z], {M, N, O, P, Z});
select_nz!(1, [[[[f32; P]; O]; N]; M], [usize; M], [[[f32; P]; O]; M], {M, N, O, P});
select_nz!(1, [[[[f32; P]; O]; N]; M], [[usize; Z]; M], [[[[f32; P]; O]; Z]; M], {M, N, O, P, Z});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[usize; N]; M], [[[f32; P]; N]; M], {M, N, O, P});
select_nz!(2, [[[[f32; P]; O]; N]; M], [[[usize; Z]; N]; M], [[[[f32; P]; Z]; N]; M], {M, N, O, P, Z});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[usize; O]; N]; M], [[[f32; O]; N]; M], {M, N, O, P});
select_nz!(-1, [[[[f32; P]; O]; N]; M], [[[[usize; Z]; O]; N]; M], [[[[f32; Z]; O]; N]; M], {M, N, O, P, Z});
// Broadcast select elements from non-zero axis.
impl<T, I, const M: usize, SubMode> DeviceSelect<T, [I; M], Broadcast<SubMode>> for Cpu
where
Self: DeviceSelect<T, I, SubMode>,
{
type Result = [<Self as DeviceSelect<T, I, SubMode>>::Result; M];

fn select_axis(inp: &T, indices: &[I; M], out: &mut Self::Result) {
for m in 0..M {
Self::select_axis(inp, &indices[m], &mut out[m]);
}
}
fn select_add(inp: &mut T, indices: &[I; M], out: &Self::Result) {
for m in 0..M {
Self::select_add(inp, &indices[m], &out[m]);
}
}
}

#[cfg(test)]
mod tests {
Expand All @@ -110,17 +143,17 @@ mod tests {

#[test]
fn test_select_1d_0() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
let a: [f32; 3] = [1.0, 2.0, 3.0];
let mut b: f32 = ZeroElements::ZEROS;
Cpu::select_axis(&a, &1, &mut b);
assert_eq!(b, 2.0);
}

#[test]
fn test_select_1d_0z() {
let a = [1.0, 2.0, 3.0];
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
let a: [f32; 3] = [1.0f32, 2.0, 3.0];
let mut b: [f32; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 1, 2, 2, 1, 0], &mut b);
assert_eq!(b, [1.0, 2.0, 3.0, 3.0, 2.0, 1.0]);
}

Expand All @@ -129,41 +162,51 @@ mod tests {
#[test]
fn test_select_2d_0() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
let mut b: [f32; 3] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, [1.0, 2.0, 3.0]);
}

#[test]
fn test_select_2d_0z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1], &mut b);
let mut b: [[f32; 3]; 3] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1], &mut b);
assert_eq!(b, [a[0], a[0], a[1]]);
}

#[test]
fn test_select_2d_1() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(&a, &[0, 1], &mut b);
let mut b: [f32; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 1], &mut b);
assert_eq!(b, [1.0, 5.0]);
}

#[test]
fn test_select_2d_1z() {
let a = A_2D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
let mut b: [[f32; 2]; 2] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0, 2], [1, 1]], &mut b);
assert_eq!(b, [[1.0, 3.0], [5.0, 5.0]]);
}

#[test]
fn test_select_broadcast_2d() {
let a = [[1.0], [2.0]];
let i: [[usize; 3]; 4] = [[0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
let mut b: [[[f32; 1]; 3]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Broadcast<Index>>>::select_axis(&a, &i, &mut b);
#[rustfmt::skip]
assert_eq!(b, [[[1.], [2.], [1.]], [[2.], [2.], [2.]], [[1.], [1.], [1.]], [[2.], [1.], [2.]]]);
}

#[test]
fn test_select_add_2d() {
let mut a = [[0.0; 3]; 2];
let b = [[1.0, 3.0], [5.0, 5.0]];
let i = [[0, 2], [1, 1]];
Cpu::select_add(&mut a, &i, &b);
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_add(&mut a, &i, &b);
assert_eq!(a, [[1.0, 0.0, 3.0], [0.0, 10.0, 0.0]]);
}

Expand All @@ -177,40 +220,40 @@ mod tests {
#[test]
fn test_select_3d_0() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
let mut b: [[f32; 3]; 2] = ZeroElements::ZEROS;
Cpu::select_axis(&a, &0, &mut b);
assert_eq!(b, A_3D[0]);
}

#[test]
fn test_select_3d_0z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
Cpu::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
let mut b: [[[f32; 3]; 2]; 6] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Index>>::select_axis(&a, &[0, 0, 1, 2, 3, 3], &mut b);
assert_eq!(b, [A_3D[0], A_3D[0], A_3D[1], A_3D[2], A_3D[3], A_3D[3]]);
}

#[test]
fn test_select_3d_1() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
let mut b: [[f32; 3]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[0, 0, 1, 1], &mut b);
assert_eq!(b, [A_3D[0][0], A_3D[1][0], A_3D[2][1], A_3D[3][1]]);
}

#[test]
fn test_select_3d_1z() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, 1>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
let mut b: [[[f32; 3]; 1]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Index>>>::select_axis(&a, &[[0], [0], [1], [1]], &mut b);
assert_eq!(b, [[A_3D[0][0]], [A_3D[1][0]], [A_3D[2][1]], [A_3D[3][1]]]);
}

#[test]
fn test_select_3d_2() {
let a = A_3D;
let mut b = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
let mut b: [[f32; 2]; 4] = ZeroElements::ZEROS;
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[1, 0], [0, 1], [0, 0], [1, 1]],
&mut b,
Expand All @@ -230,7 +273,7 @@ mod tests {
fn test_select_3d_2z() {
let a = A_3D;
let mut b: [[[f32; 1]; 2]; 4] = ZeroElements::ZEROS;
<Cpu as SelectAlongAxis<_, _, _, -1>>::select_axis(
<Cpu as DeviceSelect<_, _, Recurse<Recurse<Index>>>>::select_axis(
&a,
&[[[1], [0]], [[0], [1]], [[0], [0]], [[1], [1]]],
&mut b,
Expand Down
Loading