Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Renaming SelectTo, using SelectTo for batched select #217

Merged
merged 1 commit into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/10-tensor-index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use dfdx::arrays::HasArrayData;
use dfdx::tensor::{tensor, Tensor2D, Tensor3D};
use dfdx::tensor_ops::Select1;
use dfdx::tensor_ops::SelectTo;

fn main() {
let a: Tensor3D<3, 2, 3> = tensor([
Expand Down
2 changes: 1 addition & 1 deletion src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
//!
//! # Selects/Indexing
//!
//! Selecting or indexing into a tensor is done via [Select1::select()]. This trait enables
//! Selecting or indexing into a tensor is done via [SelectTo::select()]. This trait enables
//! 2 behaviors for each axis of a given tensor:
//!
//! 1. Select exactly 1 element from that axis.
Expand Down
119 changes: 48 additions & 71 deletions src/tensor_ops/select.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::utils::move_tape_and_add_backward_op;
use crate::devices::{
BSelectAx1, Device, DeviceSelect, FillElements, SelectAx0, SelectAx1, SelectAx2, SelectAx3,
};
use crate::devices::*;
use crate::gradients::Tape;
use crate::prelude::*;

/// Select values along a single axis `I` resulting in `T`. Equivalent
/// Select values along `Axes` resulting in `T`. Equivalent
/// to `torch.select` and `torch.gather` from pytorch.
///
/// There are two ways to select:
Expand All @@ -14,7 +12,9 @@ use crate::prelude::*;
/// 2. Select multiple values from an axis, which keeps the number
/// of dimensions the same. You can select the same element multiple
/// number of times.
pub trait Select1<T, const I: isize> {
///
/// You can also select batches of data with this trait.
pub trait SelectTo<T, Axes> {
type Indices: Clone;

/// Select sub elements using [Self::Indices].
Expand Down Expand Up @@ -44,12 +44,24 @@ pub trait Select1<T, const I: isize> {
/// // is the new size of the 1st axis.
/// let _: Tensor2D<3, 2> = Tensor2D::<3, 5>::zeros().select(&[[0, 4], [1, 3], [2, 2]]);
/// ```
///
/// Selecting batch of values from a 1d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]);
/// ```
///
/// Selecting batch of values from a 2d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]);
/// ```
fn select(self, indices: &Self::Indices) -> T;
}

macro_rules! impl_select {
($Axis:expr, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy {
($Axes:ty, $Mode:ty, $SrcTy:ty, $IndTy:tt, $DstTy:ty, {$($Dims:tt),*}) => {
impl<$(const $Dims: usize, )* H: Tape> SelectTo<$DstTy, $Axes> for $SrcTy {
type Indices = $IndTy;
fn select(self, indices: &Self::Indices) -> $DstTy {
select::<_, _, _, $Mode>(self, indices)
Expand All @@ -59,72 +71,37 @@ impl<$(const $Dims: usize, )* H: Tape> Select1<$DstTy, $Axis> for $SrcTy {
}

// 1d
impl_select!(-1, SelectAx0, Tensor1D<M, H>, usize, Tensor0D<H>, {M});
impl_select!(-1, SelectAx0, Tensor1D<M, H>, [usize; Z], Tensor1D<Z, H>, {M, Z});
impl_select!(Axis<0>, SelectAx0, Tensor1D<M, H>, usize, Tensor0D<H>, {M});
impl_select!(Axis<0>, SelectAx0, Tensor1D<M, H>, [usize; Z], Tensor1D<Z, H>, {M, Z});

// 2d
impl_select!(0, SelectAx0, Tensor2D<M, N, H>, usize, Tensor1D<N, H>, {M, N});
impl_select!(0, SelectAx0, Tensor2D<M, N, H>, [usize; Z], Tensor2D<Z, N, H>, {M, N, Z});
impl_select!(-1, SelectAx1, Tensor2D<M, N, H>, [usize; M], Tensor1D<M, H>, {M, N});
impl_select!(-1, SelectAx1, Tensor2D<M, N, H>, [[usize; Z]; M], Tensor2D<M, Z, H>, {M, N, Z});
impl_select!(Axis<0>, SelectAx0, Tensor2D<M, N, H>, usize, Tensor1D<N, H>, {M, N});
impl_select!(Axis<0>, SelectAx0, Tensor2D<M, N, H>, [usize; Z], Tensor2D<Z, N, H>, {M, N, Z});
impl_select!(Axis<1>, SelectAx1, Tensor2D<M, N, H>, [usize; M], Tensor1D<M, H>, {M, N});
impl_select!(Axis<1>, SelectAx1, Tensor2D<M, N, H>, [[usize; Z]; M], Tensor2D<M, Z, H>, {M, N, Z});

// 3d
impl_select!(0, SelectAx0, Tensor3D<M, N, O, H>, usize, Tensor2D<N, O, H>, {M, N, O});
impl_select!(0, SelectAx0, Tensor3D<M, N, O, H>, [usize; Z], Tensor3D<Z, N, O, H>, {M, N, O, Z});
impl_select!(1, SelectAx1, Tensor3D<M, N, O, H>, [usize; M], Tensor2D<M, O, H>, {M, N, O});
impl_select!(1, SelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; M], Tensor3D<M, Z, O, H>, {M, N, O, Z});
impl_select!(-1, SelectAx2, Tensor3D<M, N, O, H>, [[usize; N]; M], Tensor2D<M, N, H>, {M, N, O});
impl_select!(-1, SelectAx2, Tensor3D<M, N, O, H>, [[[usize; Z]; N]; M], Tensor3D<M, N, Z, H>, {M, N, O, Z});
impl_select!(Axis<0>, SelectAx0, Tensor3D<M, N, O, H>, usize, Tensor2D<N, O, H>, {M, N, O});
impl_select!(Axis<0>, SelectAx0, Tensor3D<M, N, O, H>, [usize; Z], Tensor3D<Z, N, O, H>, {M, N, O, Z});
impl_select!(Axis<1>, SelectAx1, Tensor3D<M, N, O, H>, [usize; M], Tensor2D<M, O, H>, {M, N, O});
impl_select!(Axis<1>, SelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; M], Tensor3D<M, Z, O, H>, {M, N, O, Z});
impl_select!(Axis<2>, SelectAx2, Tensor3D<M, N, O, H>, [[usize; N]; M], Tensor2D<M, N, H>, {M, N, O});
impl_select!(Axis<2>, SelectAx2, Tensor3D<M, N, O, H>, [[[usize; Z]; N]; M], Tensor3D<M, N, Z, H>, {M, N, O, Z});

// 4d
impl_select!(0, SelectAx0, Tensor4D<M, N, O, P, H>, usize, Tensor3D<N, O, P, H>, {M, N, O, P});
impl_select!(0, SelectAx0, Tensor4D<M, N, O, P, H>, [usize; Z], Tensor4D<Z, N, O, P, H>, {M, N, O, P, Z});
impl_select!(1, SelectAx1, Tensor4D<M, N, O, P, H>, [usize; M], Tensor3D<M, O, P, H>, {M, N, O, P});
impl_select!(1, SelectAx1, Tensor4D<M, N, O, P, H>, [[usize; Z]; M], Tensor4D<M, Z, O, P, H>, {M, N, O, P, Z});
impl_select!(2, SelectAx2, Tensor4D<M, N, O, P, H>, [[usize; N]; M], Tensor3D<M, N, P, H>, {M, N, O, P});
impl_select!(2, SelectAx2, Tensor4D<M, N, O, P, H>, [[[usize; Z]; N]; M], Tensor4D<M, N, Z, P, H>, {M, N, O, P, Z});
impl_select!(-1, SelectAx3, Tensor4D<M, N, O, P, H>, [[[usize; O]; N]; M], Tensor3D<M, N, O, H>, {M, N, O, P});
impl_select!(-1, SelectAx3, Tensor4D<M, N, O, P, H>, [[[[usize; Z]; O]; N]; M], Tensor4D<M, N, O, Z, H>, {M, N, O, P, Z});

/// Select batched values from axis 0, resulting in `T`. Equivalent
/// to `torch.select` and `torch.gather` from pytorch.
///
/// `T` is the type returned by [Self::select_batch].
pub trait SelectBatchAx0<T> {
/// The type of the indices accepted by [Self::select_batch].
type Indices;

/// Select sub elements using [Self::Indices].
/// The same element can be selected multiple times depending
/// on [Self::Indices].
///
/// This results in a tensor 1 dimension larger than self.
///
/// Selecting a batch of values from a 1d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]);
/// ```
///
/// Selecting a batch of values from a 2d tensor:
/// ```rust
/// # use dfdx::prelude::*;
/// let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]);
/// ```
fn select_batch(self, indices: &Self::Indices) -> T;
}

// Generates one `SelectBatchAx0` impl per invocation.
//
// Arguments, in order: the source tensor type, the indices type, the
// destination tensor type, and a braced list of the const generic names the
// impl must declare. The impl always declares a tape generic named `H`, which
// the tensor types passed in refer to. The generated method forwards to
// `select` with the `BSelectAx1` mode.
macro_rules! impl_select_batch {
    ($Source:ty, $Index:tt, $Output:ty, {$($Dim:tt),*}) => {
        impl<$(const $Dim: usize, )* H: Tape> SelectBatchAx0<$Output> for $Source {
            type Indices = $Index;
            fn select_batch(self, batch_indices: &Self::Indices) -> $Output {
                select::<_, _, _, BSelectAx1>(self, batch_indices)
            }
        }
    };
}

// Batched selects for 1d, 2d and 3d sources: the indices are `[[usize; Z]; B]`
// and the destination has leading `B` and `Z` dimensions in place of the
// source's first dimension `M`.
impl_select_batch!(Tensor1D<M, H>, [[usize; Z]; B], Tensor2D<B, Z, H>, {M, B, Z});
impl_select_batch!(Tensor2D<M, N, H>, [[usize; Z]; B], Tensor3D<B, Z, N, H>, {M, N, B, Z});
impl_select_batch!(Tensor3D<M, N, O, H>, [[usize; Z]; B], Tensor4D<B, Z, N, O, H>, {M, N, O, B, Z});
impl_select!(Axis<0>, SelectAx0, Tensor4D<M, N, O, P, H>, usize, Tensor3D<N, O, P, H>, {M, N, O, P});
impl_select!(Axis<0>, SelectAx0, Tensor4D<M, N, O, P, H>, [usize; Z], Tensor4D<Z, N, O, P, H>, {M, N, O, P, Z});
impl_select!(Axis<1>, SelectAx1, Tensor4D<M, N, O, P, H>, [usize; M], Tensor3D<M, O, P, H>, {M, N, O, P});
impl_select!(Axis<1>, SelectAx1, Tensor4D<M, N, O, P, H>, [[usize; Z]; M], Tensor4D<M, Z, O, P, H>, {M, N, O, P, Z});
impl_select!(Axis<2>, SelectAx2, Tensor4D<M, N, O, P, H>, [[usize; N]; M], Tensor3D<M, N, P, H>, {M, N, O, P});
impl_select!(Axis<2>, SelectAx2, Tensor4D<M, N, O, P, H>, [[[usize; Z]; N]; M], Tensor4D<M, N, Z, P, H>, {M, N, O, P, Z});
impl_select!(Axis<3>, SelectAx3, Tensor4D<M, N, O, P, H>, [[[usize; O]; N]; M], Tensor3D<M, N, O, H>, {M, N, O, P});
impl_select!(Axis<3>, SelectAx3, Tensor4D<M, N, O, P, H>, [[[[usize; Z]; O]; N]; M], Tensor4D<M, N, O, Z, H>, {M, N, O, P, Z});

// batched select
impl_select!(Axis<0>, BSelectAx1, Tensor1D<M, H>, [[usize; Z]; B], Tensor2D<B, Z, H>, {M, B, Z});
impl_select!(Axis<0>, BSelectAx1, Tensor2D<M, N, H>, [[usize; Z]; B], Tensor3D<B, Z, N, H>, {M, N, B, Z});
impl_select!(Axis<0>, BSelectAx1, Tensor3D<M, N, O, H>, [[usize; Z]; B], Tensor4D<B, Z, N, O, H>, {M, N, O, B, Z});

pub(crate) fn select<T, I, R, Mode>(t: T, indices: &I) -> R
where
Expand Down Expand Up @@ -162,9 +139,9 @@ mod tests {

#[test]
fn test_valid_select_batches() {
let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select_batch(&[[0], [1]]);
let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select_batch(&[[0], [1]]);
let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select_batch(&[[0], [0]]);
let _: Tensor2D<2, 1> = Tensor1D::<5>::zeros().select(&[[0], [1]]);
let _: Tensor3D<2, 1, 5> = Tensor2D::<3, 5>::zeros().select(&[[0], [1]]);
let _: Tensor4D<2, 1, 3, 5> = Tensor3D::<1, 3, 5>::zeros().select(&[[0], [0]]);
}

#[test]
Expand Down Expand Up @@ -255,7 +232,7 @@ mod tests {
fn test_select_batch_backwards() {
let mut rng = thread_rng();
let t: Tensor2D<4, 5> = TensorCreator::randn(&mut rng);
let r: Tensor3D<2, 3, 5, _> = t.trace().select_batch(&[[2, 0, 3], [0, 0, 3]]);
let r: Tensor3D<2, 3, 5, _> = t.trace().select(&[[2, 0, 3], [0, 0, 3]]);
let r0: Tensor2D<3, 5> = t.clone().select(&[2, 0, 3]);
let r1: Tensor2D<3, 5> = t.clone().select(&[0, 0, 3]);
assert_close(&r.data()[0], r0.data());
Expand Down