Skip to content

Commit

Permalink
Renaming SelectTo, using SelectTo for batched select (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Oct 5, 2022
1 parent 5f587dd commit 09f8084
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 73 deletions.
2 changes: 1 addition & 1 deletion examples/10-tensor-index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use dfdx::arrays::HasArrayData;
use dfdx::tensor::{tensor, Tensor2D, Tensor3D};
use dfdx::tensor_ops::Select1;
use dfdx::tensor_ops::SelectTo;

fn main() {
let a: Tensor3D<3, 2, 3> = tensor([
Expand Down
2 changes: 1 addition & 1 deletion src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
//!
//! # Selects/Indexing
//!
//! Selecting or indexing into a tensor is done via [Select1::select()]. This traits enables
//! Selecting or indexing into a tensor is done via [SelectTo::select()]. This traits enables
//! 2 behaviors for each axis of a given tensor:
//!
//! 1. Select exactly 1 element from that axis.
Expand Down
119 changes: 48 additions & 71 deletions src/tensor_ops/select.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::utils::move_tape_and_add_backward_op;
use crate::devices::{
BSelectAx1, Device, DeviceSelect, FillElements, SelectAx0, SelectAx1, SelectAx2, SelectAx3,
};
use crate::devices::*;
use crate::gradients::Tape;
use crate::prelude::*;

/// Select values along a single axis `I` resulting in `T`. Equivalent
/// Select values along `Axes` resulting in `T`. Equivalent
/// to `torch.select` and `torch.gather` from pytorch.
///
/// There are two ways to select:
Expand All @@ -14,7 +12,9 @@ use crate::prelude::*;
/// 2. Select multiple values from an axis, which keeps the number
/// of dimensions the same. You can select the same element multiple
/// number of times.
pub trait Select1<T, const I: isize> {
///
/// You can also select batches of data with this trait.
pub trait SelectTo<T, Axes> {
type Indices: Clone;

/// Select sub elements using [Self::Indices].
Expand Down Expand Up @@ -44,12 +44,24 @@ pub trait Select1<T, const I: isize> {
/// // is the new size of the 1st axis.
/// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]);
/// ```
///
/// Selecting batch of values from a 1d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]);
///```
///
/// Selecting batch of values from a 2d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]);
///```
fn select(self, indices: &Self::Indices) -> T;
}

macro_rules! impl_select {
($Axis:expr, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy {
($Axes:ty, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize, )* H: Tape> SelectTo<$DstTy, $Axes> for $SrcTy {
type Indices = $IndTy;
fn select(self, indices: &Self::Indices) -> $DstTy {
select::<_, _, _, $Mode>(self, indices)
Expand All @@ -59,72 +71,37 @@ impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy {
}

// 1d
impl_select!(-1, SelectAx0, Tensor1D<M, H>, usize, Tensor0D<H>, {M});
impl_select!(-1, SelectAx0, Tensor1D<M, H>, [usize; Z], Tensor1D<Z, H>, {M, Z});
impl_select!(Axis<0>, SelectAx0, Tensor1D<M, H>, usize, Tensor0D<H>, {M});
impl_select!(Axis<0>, SelectAx0, Tensor1D<M, H>, [usize; Z], Tensor1D<Z, H>, {M, Z});

// 2d
impl_select!(0, SelectAx0, Tensor2D<M, N, H>, usize, Tensor1D<N, H>, {M, N});
impl_select!(0, SelectAx0, Tensor2D<M, N, H>, [usize; Z], Tensor2D<Z, N, H>, {M, N, Z});
impl_select!(-1, SelectAx1, Tensor2D<M, N, H>, [usize; M], Tensor1D<M, H>, {M, N});
impl_select!(-1, SelectAx1, Tensor2D<M, N, H>, [[usize; Z]; M], Tensor2D<M, Z, H>, {M, N, Z});
impl_select!(Axis<0>, SelectAx0, Tensor2D<M, N, H>, usize, Tensor1D<N, H>, {M, N});
impl_select!(Axis<0>, SelectAx0, Tensor2D<M, N, H>, [usize; Z], Tensor2D<Z, N, H>, {M, N, Z});
impl_select!(Axis<1>, SelectAx1, Tensor2D<M, N, H>, [usize; M], Tensor1D<M, H>, {M, N});
impl_select!(Axis<1>, SelectAx1, Tensor2D<M, N, H>, [[usize; Z]; M], Tensor2D<M, Z, H>, {M, N, Z});

// 3d
impl_select!(0, SelectAx0, Tensor3D<M, N, O, H>, usize, Tensor2D<N, O, H>, {M, N, O});
impl_select!(0, SelectAx0, Tensor3D<M, N, O, H>, [usize; Z], Tensor3D<Z, N, O, H>, {M, N, O, Z});
impl_select!(1, SelectAx1, Tensor3D<M, N, O, H>, [usize; M], Tensor2D<M, O, H>, {M, N, O});
impl_select!(1, SelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; M], Tensor3D<M, Z, O, H>, {M, N, O, Z});
impl_select!(-1, SelectAx2, Tensor3D<M, N, O, H>, [[usize; N]; M], Tensor2D<M, N, H>, {M, N, O});
impl_select!(-1, SelectAx2, Tensor3D<M, N, O, H>, [[[usize; Z]; N]; M], Tensor3D<M, N, Z, H>, {M, N, O, Z});
impl_select!(Axis<0>, SelectAx0, Tensor3D<M, N, O, H>, usize, Tensor2D<N, O, H>, {M, N, O});
impl_select!(Axis<0>, SelectAx0, Tensor3D<M, N, O, H>, [usize; Z], Tensor3D<Z, N, O, H>, {M, N, O, Z});
impl_select!(Axis<1>, SelectAx1, Tensor3D<M, N, O, H>, [usize; M], Tensor2D<M, O, H>, {M, N, O});
impl_select!(Axis<1>, SelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; M], Tensor3D<M, Z, O, H>, {M, N, O, Z});
impl_select!(Axis<2>, SelectAx2, Tensor3D<M, N, O, H>, [[usize; N]; M], Tensor2D<M, N, H>, {M, N, O});
impl_select!(Axis<2>, SelectAx2, Tensor3D<M, N, O, H>, [[[usize; Z]; N]; M], Tensor3D<M, N, Z, H>, {M, N, O, Z});

// 4d
impl_select!(0, SelectAx0, Tensor4D<M, N, O, P, H>, usize, Tensor3D<N, O, P, H>, {M, N, O, P});
impl_select!(0, SelectAx0, Tensor4D<M, N, O, P, H>, [usize; Z], Tensor4D<Z, N, O, P, H>, {M, N, O, P, Z});
impl_select!(1, SelectAx1, Tensor4D<M, N, O, P, H>, [usize; M], Tensor3D<M, O, P, H>, {M, N, O, P});
impl_select!(1, SelectAx1, Tensor4D<M, N, O, P, H>, [[usize; Z]; M], Tensor4D<M, Z, O, P, H>, {M, N, O, P, Z});
impl_select!(2, SelectAx2, Tensor4D<M, N, O, P, H>, [[usize; N]; M], Tensor3D<M, N, P, H>, {M, N, O, P});
impl_select!(2, SelectAx2, Tensor4D<M, N, O, P, H>, [[[usize; Z]; N]; M], Tensor4D<M, N, Z, P, H>, {M, N, O, P, Z});
impl_select!(-1, SelectAx3, Tensor4D<M, N, O, P, H>, [[[usize; O]; N]; M], Tensor3D<M, N, O, H>, {M, N, O, P});
impl_select!(-1, SelectAx3, Tensor4D<M, N, O, P, H>, [[[[usize; Z]; O]; N]; M], Tensor4D<M, N, O, Z, H>, {M, N, O, P, Z});

/// Select batched values from axis 0, resulting in `T`. Equivalent
/// to `torch.select` and `torch.gather` from pytorch.
pub trait SelectBatchAx0<T> {
type Indices;

/// Select sub elements using [Self::Indices].
/// The same element can be selected multiple times depending
/// on [Self::Indices].
///
/// This results in a tensor 1 dimension larger than self.
///
/// Selecting batch of values from a 1d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]);
///```
///
/// Selecting batch of values from a 2d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]);
///```
fn select_batch(self, indices: &Self::Indices) -> T;
}

macro_rules! impl_select_batch {
($SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize, )* H: Tape> SelectBatchAx0<$DstTy> for $SrcTy {
type Indices = $IndTy;
fn select_batch(self, indices: &Self::Indices) -> $DstTy {
select::<_, _, _, BSelectAx1>(self, indices)
}
}
};
}

impl_select_batch!(Tensor1D<M, H>, [[usize; Z]; B], Tensor2D<B, Z, H>, {M, B, Z});
impl_select_batch!(Tensor2D<M, N, H>, [[usize; Z]; B], Tensor3D<B, Z, N, H>, {M, N, B, Z});
impl_select_batch!(Tensor3D<M, N, O, H>, [[usize; Z]; B], Tensor4D<B, Z, N, O, H>, {M, N, O, B, Z});
impl_select!(Axis<0>, SelectAx0, Tensor4D<M, N, O, P, H>, usize, Tensor3D<N, O, P, H>, {M, N, O, P});
impl_select!(Axis<0>, SelectAx0, Tensor4D<M, N, O, P, H>, [usize; Z], Tensor4D<Z, N, O, P, H>, {M, N, O, P, Z});
impl_select!(Axis<1>, SelectAx1, Tensor4D<M, N, O, P, H>, [usize; M], Tensor3D<M, O, P, H>, {M, N, O, P});
impl_select!(Axis<1>, SelectAx1, Tensor4D<M, N, O, P, H>, [[usize; Z]; M], Tensor4D<M, Z, O, P, H>, {M, N, O, P, Z});
impl_select!(Axis<2>, SelectAx2, Tensor4D<M, N, O, P, H>, [[usize; N]; M], Tensor3D<M, N, P, H>, {M, N, O, P});
impl_select!(Axis<2>, SelectAx2, Tensor4D<M, N, O, P, H>, [[[usize; Z]; N]; M], Tensor4D<M, N, Z, P, H>, {M, N, O, P, Z});
impl_select!(Axis<3>, SelectAx3, Tensor4D<M, N, O, P, H>, [[[usize; O]; N]; M], Tensor3D<M, N, O, H>, {M, N, O, P});
impl_select!(Axis<3>, SelectAx3, Tensor4D<M, N, O, P, H>, [[[[usize; Z]; O]; N]; M], Tensor4D<M, N, O, Z, H>, {M, N, O, P, Z});

// batched select
impl_select!(Axis<0>, BSelectAx1, Tensor1D<M, H>, [[usize; Z]; B], Tensor2D<B, Z, H>, {M, B, Z});
impl_select!(Axis<0>, BSelectAx1, Tensor2D<M, N, H>, [[usize; Z]; B], Tensor3D<B, Z, N, H>, {M, N, B, Z});
impl_select!(Axis<0>, BSelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; B], Tensor4D<B, Z, N, O, H>, {M, N, O, B, Z});

pub(crate) fn select<T, I, R, Mode>(t: T, indices: &I) -> R
where
Expand Down Expand Up @@ -162,9 +139,9 @@ mod tests {

#[test]
fn test_valid_select_batches() {
let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]);
let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]);
let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [0]]);
let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]);
let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]);
let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select(&[[0], [0]]);
}

#[test]
Expand Down Expand Up @@ -255,7 +232,7 @@ mod tests {
fn test_select_batch_backwards() {
let mut rng = thread_rng();
let t: Tensor2D<4, 5> = TensorCreator::randn(&mut rng);
let r: Tensor3D<2, 3, 5, _> = t.trace().select_batch(&[[2, 0, 3], [0, 0, 3]]);
let r: Tensor3D<2, 3, 5, _> = t.trace().select(&[[2, 0, 3], [0, 0, 3]]);
let r0: Tensor2D<3, 5> = t.clone().select(&[2, 0, 3]);
let r1: Tensor2D<3, 5> = t.clone().select(&[0, 0, 3]);
assert_close(&r.data()[0], r0.data());
Expand Down

0 comments on commit 09f8084

Please sign in to comment.