diff --git a/src/shapes/broadcasts.rs b/src/shapes/broadcasts.rs
index 596a2f60a..d4fe7c5e6 100644
--- a/src/shapes/broadcasts.rs
+++ b/src/shapes/broadcasts.rs
@@ -50,6 +50,8 @@ macro_rules! length {
     ($x:tt $($xs:tt)*) => {1 + length!($($xs)*)};
 }
 
+pub(crate) use length;
+
 // Defines all reduce/broadcast rules recursively
 macro_rules! broadcast_to_all {
     ([$($s1:ident)*] [$($s2:ident)*] [$($ax:tt)*] [] [$axis:tt $($axes:tt)*]) => {
diff --git a/src/shapes/mod.rs b/src/shapes/mod.rs
index 028e762f5..0168da817 100644
--- a/src/shapes/mod.rs
+++ b/src/shapes/mod.rs
@@ -16,6 +16,7 @@ mod realize;
 mod replace_dim;
 mod same_numel;
 mod shape;
+mod slice;
 
 pub(crate) use axes::Axes;
 pub(crate) use broadcasts::{
@@ -26,6 +27,7 @@ pub(crate) use realize::RealizeShapeTo;
 pub(crate) use replace_dim::{RemoveDimTo, ReplaceDimTo};
 pub(crate) use same_numel::AssertSameNumel;
+pub(crate) use slice::SliceShape;
 
 pub use axes::{Axes2, Axes3, Axes4, Axes5, Axes6, Axis, HasAxes};
 pub use shape::{Array, Const, ConstDim, Dim};
diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs
index f218a2024..55662a83c 100644
--- a/src/shapes/shape.rs
+++ b/src/shapes/shape.rs
@@ -195,7 +195,8 @@ pub trait Shape:
         + Send
         + Sync
         + IntoIterator<Item = usize>
-        + Into<std::vec::Vec<usize>>;
+        + Into<std::vec::Vec<usize>>
+        + AsRef<[usize]>;
 
     /// All the axes of this shape
     type AllAxes: Axes;
diff --git a/src/shapes/slice.rs b/src/shapes/slice.rs
new file mode 100644
index 000000000..10634314f
--- /dev/null
+++ b/src/shapes/slice.rs
@@ -0,0 +1,119 @@
+use super::*;
+use std::ops::{Bound, RangeBounds};
+
+fn get_start_bound(bound: Bound<&usize>) -> usize {
+    match bound {
+        Bound::Included(x) => *x,
+        Bound::Excluded(x) => x + 1,
+        Bound::Unbounded => 0,
+    }
+}
+
+fn get_end_bound(bound: Bound<&usize>, size: usize) -> usize {
+    match bound {
+        Bound::Excluded(x) => *x,
+        Bound::Included(x) => x + 1,
+        Bound::Unbounded => size,
+    }
+}
+
+pub trait SliceDim<R: RangeBounds<usize>>: Dim {
+    type Sliced: Dim;
+
+    fn slice(&self, range: &R) -> Option<Self::Sliced> {
+        let size = self.size();
+
+        let start_bound = get_start_bound(range.start_bound());
+        let end_bound = get_end_bound(range.end_bound(), size);
+
+        (end_bound <= size && start_bound <= end_bound)
+            .then_some(end_bound - start_bound)
+            .and_then(Self::Sliced::from_size)
+    }
+}
+
+macro_rules! slice_dim_to_usize {
+    ($range:ty) => {
+        impl<D: Dim> SliceDim<$range> for D {
+            type Sliced = usize;
+        }
+    };
+}
+
+slice_dim_to_usize!(std::ops::Range<usize>);
+slice_dim_to_usize!(std::ops::RangeTo<usize>);
+slice_dim_to_usize!(std::ops::RangeFrom<usize>);
+slice_dim_to_usize!(std::ops::RangeInclusive<usize>);
+slice_dim_to_usize!(std::ops::RangeToInclusive<usize>);
+
+impl<D: Dim> SliceDim<std::ops::RangeFull> for D {
+    type Sliced = D;
+
+    fn slice(&self, _: &std::ops::RangeFull) -> Option<D> {
+        Some(*self)
+    }
+}
+
+pub trait SliceShape<R>: Shape {
+    type Sliced: Shape;
+
+    fn slice(&self, range: &R) -> Option<Self::Sliced>;
+    fn first_idx_in_slice(&self, range: &R) -> usize;
+}
+
+impl SliceShape<()> for () {
+    type Sliced = Self;
+
+    fn slice(&self, _range: &()) -> Option<Self> {
+        Some(())
+    }
+
+    fn first_idx_in_slice(&self, _range: &()) -> usize {
+        0
+    }
+}
+
+use super::broadcasts::length;
+
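+// For reference, the 1-d expansion of `slice_shape!` below looks like this
+// (illustrative sketch only):
+//
+//     impl<D1: Dim, R1: RangeBounds<usize>> SliceShape<(R1,)> for (D1,)
+//     where
+//         D1: SliceDim<R1>,
+//     {
+//         type Sliced = (D1::Sliced,);
+//
+//         fn slice(&self, range: &(R1,)) -> Option<Self::Sliced> {
+//             Some((self.0.slice(&range.0)?,))
+//         }
+//
+//         fn first_idx_in_slice(&self, range: &(R1,)) -> usize {
+//             let strides = self.strides();
+//             get_start_bound(range.0.start_bound()) * strides[0] + 0
+//         }
+//     }
+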
+macro_rules! slice_shape {
+    ([$($dim:ident)*] [$($range:ident)*] [$($idx:tt)*]) => {
+        impl<$($dim: Dim),*, $($range: RangeBounds<usize>),*> SliceShape<($($range,)*)> for ($($dim,)*)
+        where
+            $($dim: SliceDim<$range>),*
+        {
+            type Sliced = ($($dim::Sliced,)*);
+
+            fn slice(&self, range: &($($range,)*)) -> Option<Self::Sliced> {
+                Some(($(self.$idx.slice(&range.$idx)?,)*))
+            }
+
+            fn first_idx_in_slice(&self, range: &($($range,)*)) -> usize {
+                let strides = self.strides();
+                $(get_start_bound(range.$idx.start_bound()) * strides[$idx] + )* 0
+            }
+        }
+
+        impl<$($range: RangeBounds<usize>),*> SliceShape<($($range,)*)> for [usize; {length!($($range)*)}]
+        where
+            $(usize: SliceDim<$range>),*
+        {
+            type Sliced = ($(<usize as SliceDim<$range>>::Sliced,)*);
+
+            fn slice(&self, range: &($($range,)*)) -> Option<Self::Sliced> {
+                Some(($(self[$idx].slice(&range.$idx)?,)*))
+            }
+
+            fn first_idx_in_slice(&self, range: &($($range,)*)) -> usize {
+                let strides = self.strides();
+                $(get_start_bound(range.$idx.start_bound()) * strides[$idx] + )* 0
+            }
+        }
+    }
+}
+
+slice_shape!([D1] [R1] [0]);
+slice_shape!([D1 D2] [R1 R2] [0 1]);
+slice_shape!([D1 D2 D3] [R1 R2 R3] [0 1 2]);
+slice_shape!([D1 D2 D3 D4] [R1 R2 R3 R4] [0 1 2 3]);
+slice_shape!([D1 D2 D3 D4 D5] [R1 R2 R3 R4 R5] [0 1 2 3 4]);
+slice_shape!([D1 D2 D3 D4 D5 D6] [R1 R2 R3 R4 R5 R6] [0 1 2 3 4 5]);
diff --git a/src/tensor/cpu/iterate.rs b/src/tensor/cpu/iterate.rs
index 4b9184999..ab15c0be0 100644
--- a/src/tensor/cpu/iterate.rs
+++ b/src/tensor/cpu/iterate.rs
@@ -25,6 +25,20 @@ impl<S: Shape> NdIndex<S> {
 }
 
 impl<S: Shape> NdIndex<S> {
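+    /// Converts a logical element index (counting elements of `shape` in
+    /// row-major order) into an offset into the underlying buffer, walking
+    /// dims and strides from the innermost axis outwards. With broadcast
+    /// (0-stride) axes, several logical indices map to the same offset.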
+    pub(crate) fn get_strided_index(&self, mut idx: usize) -> usize {
+        let mut out = 0;
+
+        let shape = self.shape.as_ref();
+        let strides = self.strides.as_ref();
+
+        for (dim, stride) in shape.iter().zip(strides.iter()).rev() {
+            out += (idx % dim) * stride;
+            idx /= dim;
+        }
+
+        out
+    }
+
     #[inline(always)]
     pub(crate) fn next(&mut self) -> Option<usize> {
         match self.contiguous {
diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs
index 581f80f90..d3d4cc6ab 100644
--- a/src/tensor_ops/mod.rs
+++ b/src/tensor_ops/mod.rs
@@ -187,6 +187,7 @@ mod reshape_to;
 mod select_and_gather;
 mod sigmoid;
 mod sin;
+mod slice;
 mod softmax;
 mod sqrt;
 mod square;
@@ -236,6 +237,7 @@ pub use reshape_to::ReshapeTo;
 pub use select_and_gather::{GatherTo, SelectTo};
 pub use sigmoid::sigmoid;
 pub use sin::sin;
+pub use slice::slice;
 pub use softmax::softmax;
 pub use sqrt::sqrt;
 pub use square::square;
diff --git a/src/tensor_ops/reshape_to/cuda_kernel.rs b/src/tensor_ops/reshape_to/cuda_kernel.rs
index 14d350e4e..a09a5fcfa 100644
--- a/src/tensor_ops/reshape_to/cuda_kernel.rs
+++ b/src/tensor_ops/reshape_to/cuda_kernel.rs
@@ -11,6 +11,19 @@ trait HasCudaKernel<E> {
     const FNS: &'static [&'static str];
 }
 
+// Integer/bool dtypes only get a forward kernel; there is no gradient to reshape.
+macro_rules! has_kernels {
+    ($($dtype:ty),*) => {
+        $(
+            impl HasCudaKernel<$dtype> for Cuda {
+                const MOD: &'static str = concat!("reshape_", stringify!($dtype));
+                const FNS: &'static [&'static str] = &[concat!("reshape_fwd_", stringify!($dtype))];
+            }
+        )*
+    }
+}
+
+has_kernels!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, bool);
+
 impl HasCudaKernel<f32> for Cuda {
     const MOD: &'static str = "reshape_f32";
     const FNS: &'static [&'static str] = &["reshape_fwd_f32", "reshape_bwd_f32"];
diff --git a/src/tensor_ops/reshape_to/reshape.cu b/src/tensor_ops/reshape_to/reshape.cu
index ca3517f87..9a4380d6d 100644
--- a/src/tensor_ops/reshape_to/reshape.cu
+++ b/src/tensor_ops/reshape_to/reshape.cu
@@ -1,3 +1,4 @@
+#include <cstdint>
 #include "cuda_utils.cuh"
 
 template<typename T>
@@ -46,8 +47,8 @@ __device__ void reshape_bwd(
     atomicAdd(grad_inp + inp_i, grad_out[out_i]);
 }
 
-#define RESHAPE(TYPENAME, FWD, BWD) \
-extern "C" __global__ void FWD( \
+#define RESHAPE_FWD(TYPENAME, FN) \
+extern "C" __global__ void FN( \
     const size_t numel, \
     const TYPENAME *inp, \
     const size_t inp_num_dims, \
@@ -59,7 +60,11 @@ extern "C" __global__ void FWD( \
     const size_t *out_strides \
 ) { \
     reshape_fwd(numel, inp, inp_num_dims, inp_dims, inp_strides, out, out_num_dims, out_dims, out_strides); \
-} \
+}
+
+#define RESHAPE(TYPENAME, FWD, BWD) \
+RESHAPE_FWD(TYPENAME, FWD) \
+\
 extern "C" __global__ void BWD( \
     const size_t numel, \
     TYPENAME *grad_inp, \
@@ -76,3 +81,14 @@ extern "C" __global__ void BWD( \
 
 RESHAPE(float, reshape_fwd_f32, reshape_bwd_f32);
 RESHAPE(double, reshape_fwd_f64, reshape_bwd_f64);
+RESHAPE_FWD(uint8_t, reshape_fwd_u8);
+RESHAPE_FWD(uint16_t, reshape_fwd_u16);
+RESHAPE_FWD(uint32_t, reshape_fwd_u32);
+RESHAPE_FWD(uint64_t, reshape_fwd_u64);
+RESHAPE_FWD(uintptr_t, reshape_fwd_usize);
+RESHAPE_FWD(int8_t, reshape_fwd_i8);
+RESHAPE_FWD(int16_t, reshape_fwd_i16);
+RESHAPE_FWD(int32_t, reshape_fwd_i32);
+RESHAPE_FWD(int64_t, reshape_fwd_i64);
+RESHAPE_FWD(intptr_t, reshape_fwd_isize);
+RESHAPE_FWD(bool, reshape_fwd_bool);
diff --git a/src/tensor_ops/slice/cpu_kernel.rs b/src/tensor_ops/slice/cpu_kernel.rs
new file mode 100644
index 000000000..c9f5cdb6e
--- /dev/null
+++ b/src/tensor_ops/slice/cpu_kernel.rs
@@ -0,0 +1,52 @@
+use crate::prelude::cpu::{LendingIterator, NdIndex};
+
+use super::*;
+
+impl<E: Unit> SliceKernel<E> for Cpu {
+    fn forward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        slice: &Slice,
+    ) -> Result<Tensor<Src::Sliced, E, Self>, Self::Err> {
+        let dst = inp.shape.slice(slice).unwrap();
+        let mut out = self.try_zeros_like(&dst)?;
+
+        // Iterate the sliced shape with the *input's* strides, reading from a
+        // view that starts at the slice's first element.
+        let mut inp_idx = NdIndex::new(dst, inp.strides);
+        let mut out_iter = out.iter_mut();
+
+        let start_idx = NdIndex::new(inp.shape, inp.strides)
+            .get_strided_index(inp.shape.first_idx_in_slice(slice));
+        let view = &inp.data[start_idx..];
+
+        while let Some((inp_i, o)) = inp_idx.next().zip(out_iter.next()) {
+            *o = view[inp_i];
+        }
+
+        Ok(out)
+    }
+
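+    // Backward mirrors forward: iterate the sliced view of `grad_inp` with the
+    // input's strides and write each `grad_out` element back to the position
+    // it was read from in the forward pass.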
+    fn backward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        grad_inp: &mut Vec<E>,
+        grad_out: &Vec<E>,
+        slice: &Slice,
+    ) -> Result<(), Self::Err> {
+        let dst = inp.shape.slice(slice).unwrap();
+
+        let mut inp_idx = NdIndex::new(dst, inp.strides);
+        let mut out_iter = grad_out.iter();
+
+        let start_idx = NdIndex::new(inp.shape, inp.strides)
+            .get_strided_index(inp.shape.first_idx_in_slice(slice));
+        let view = &mut grad_inp[start_idx..];
+
+        while let Some((inp_i, o)) = inp_idx.next().zip(out_iter.next()) {
+            view[inp_i] = *o;
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/tensor_ops/slice/cuda_kernel.rs b/src/tensor_ops/slice/cuda_kernel.rs
new file mode 100644
index 000000000..eec114f15
--- /dev/null
+++ b/src/tensor_ops/slice/cuda_kernel.rs
@@ -0,0 +1,113 @@
+use crate::{
+    prelude::cpu::NdIndex,
+    shapes::*,
+    tensor::{Cuda, Tensor},
+};
+use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
+
+const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/slice.ptx"));
+
+pub(crate) trait HasCudaKernel<E> {
+    const MOD: &'static str;
+    const FNS: &'static [&'static str];
+}
+
+macro_rules! has_kernels {
+    ($($dtype:ty),*) => {
+        $(
+            impl HasCudaKernel<$dtype> for Cuda {
+                const MOD: &'static str = concat!("slice_", stringify!($dtype));
+                const FNS: &'static [&'static str] = &[concat!("slice_fwd_", stringify!($dtype))];
+            }
+        )*
+    }
+}
+
+has_kernels!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, bool);
+
+impl HasCudaKernel<f32> for Cuda {
+    const MOD: &'static str = "slice_f32";
+    const FNS: &'static [&'static str] = &["slice_fwd_f32", "slice_bwd_f32"];
+}
+
+impl HasCudaKernel<f64> for Cuda {
+    const MOD: &'static str = "slice_f64";
+    const FNS: &'static [&'static str] = &["slice_fwd_f64", "slice_bwd_f64"];
+}
+
+impl<E: Unit> super::SliceKernel<E> for Cuda
+where
+    Self: HasCudaKernel<E>,
+{
+    fn forward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        slice: &Slice,
+    ) -> Result<Tensor<Src::Sliced, E, Self>, Self::Err> {
+        if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
+            self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
+        }
+
+        let dst = inp.shape.slice(slice).unwrap();
+        let strides = inp.strides;
+        let numel = dst.num_elements();
+
+        let start_idx = NdIndex::new(inp.shape, inp.strides)
+            .get_strided_index(inp.shape.first_idx_in_slice(slice));
+
+        let mut storage = unsafe { self.dev.alloc::<E>(numel) }?;
+
+        let dims: CudaSlice<usize> = self.dev.htod_copy(dst.concrete().into())?;
+        let strides: CudaSlice<usize> = self.dev.htod_copy(strides.into())?;
+
+        let fwd_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
+        let cfg = LaunchConfig::for_num_elems(numel as u32);
+        let params = (
+            numel,             // const size_t numel,
+            Src::NUM_DIMS,     // const size_t num_dims,
+            &dims,             // const size_t *dims,
+            &strides,          // const size_t *strides,
+            start_idx,         // const size_t offset,
+            inp.data.as_ref(), // const T *inp,
+            &mut storage,      // T *out
+        );
+        unsafe { fwd_fn.launch(cfg, params) }?;
+        Ok(self.build_tensor(dst, dst.strides(), storage))
+    }
+
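+    // Backward launches the `slice_bwd_*` kernel with the same geometry as
+    // forward; the kernel scatters `grad_out` into `grad_inp` via atomicAdd
+    // (see slice.cu), since several threads may hit the same input offset
+    // when the input has broadcast (0-stride) axes.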
+    fn backward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        grad_inp: &mut CudaSlice<E>,
+        grad_out: &CudaSlice<E>,
+        slice: &Slice,
+    ) -> Result<(), Self::Err> {
+        if !self.dev.has_func(Self::MOD, Self::FNS[1]) {
+            self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
+        }
+
+        let dst = inp.shape.slice(slice).unwrap();
+        let strides = inp.strides;
+        let numel = dst.num_elements();
+
+        let start_idx = NdIndex::new(inp.shape, inp.strides)
+            .get_strided_index(inp.shape.first_idx_in_slice(slice));
+
+        let dims: CudaSlice<usize> = self.dev.htod_copy(dst.concrete().into())?;
+        let strides: CudaSlice<usize> = self.dev.htod_copy(strides.into())?;
+
+        let bwd_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
+        let cfg = LaunchConfig::for_num_elems(numel as u32);
+        let params = (
+            numel,         // const size_t numel,
+            Src::NUM_DIMS, // const size_t num_dims,
+            &dims,         // const size_t *dims,
+            &strides,      // const size_t *strides,
+            start_idx,     // const size_t offset,
+            grad_inp,      // T *grad_inp,
+            grad_out,      // const T *grad_out
+        );
+        unsafe { bwd_fn.launch(cfg, params) }?;
+        Ok(())
+    }
+}
diff --git a/src/tensor_ops/slice/mod.rs b/src/tensor_ops/slice/mod.rs
new file mode 100644
index 000000000..6870df068
--- /dev/null
+++ b/src/tensor_ops/slice/mod.rs
@@ -0,0 +1,179 @@
+use crate::{shapes::*, tensor::*};
+
+mod cpu_kernel;
+#[cfg(feature = "cuda")]
+mod cuda_kernel;
+
+pub trait SliceKernel<E: Unit>: DeviceStorage {
+    fn forward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        slice: &Slice,
+    ) -> Result<Tensor<Src::Sliced, E, Self>, Self::Err>;
+
+    fn backward<Src: Shape + SliceShape<Slice>, Slice>(
+        &self,
+        inp: &Tensor<Src, E, Self>,
+        grad_inp: &mut Self::Vec<E>,
+        grad_out: &Self::Vec<E>,
+        slice: &Slice,
+    ) -> Result<(), Self::Err>;
+}
+
+/// Slices all dimensions of a tensor, with the starting and ending indices of each dimension
+/// determined by a tuple of ranges.
+///
+/// Slices are specified as tuples of ranges defined with the `..` and `..=` operators. All
+/// sliced dimensions are changed to be of type usize except those sliced with `..`
+/// ([std::ops::RangeFull]), whose types are not modified.
+///
+/// Example:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev = Cpu::default();
+/// let a = dev.tensor([
+///     [1., 2.],
+///     [3., 4.],
+/// ]);
+///
+/// // Slice the first row to get a 1x2 tensor
+/// let b: Tensor<Rank2<1, 2>, _, _> = a.clone().slice((0..1, 0..2)).realize().unwrap();
+/// assert_eq!(b.array(), [[1., 2.]]);
+///
+/// // Slice the last column to get a 2x1 tensor
+/// let c: Tensor<Rank2<2, 1>, _, _> = a.clone().slice((0..2, 1..)).realize().unwrap();
+/// assert_eq!(c.array(), [[2.], [4.]]);
+/// ```
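+///
+/// Dimensions sliced with `..` keep their compile-time size, so a slice that
+/// uses only full ranges needs no `realize`:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev = Cpu::default();
+/// # let a: Tensor<Rank2<2, 2>, f32, _> = dev.tensor([[1., 2.], [3., 4.]]);
+/// let d: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((.., ..));
+/// assert_eq!(d.array(), [[1., 2.], [3., 4.]]);
+/// ```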
+pub fn slice<S: SliceShape<Slice>, E: Unit, D: SliceKernel<E>, T: Tape<E, D>, Slice: 'static>(
+    tensor: Tensor<S, E, D, T>,
+    slice: Slice,
+) -> Tensor<S::Sliced, E, D, T> {
+    tensor.slice(slice)
+}
+
+impl<S: Shape, E: Unit, D: SliceKernel<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// Fallible version of [Tensor::slice]
+    pub fn try_slice<Slice>(self, slice: Slice) -> Result<Tensor<S::Sliced, E, D, T>, D::Err>
+    where
+        S: SliceShape<Slice>,
+        Slice: 'static,
+    {
+        let (inp, mut tape) = self.split_tape();
+        let out = inp.device.forward(&inp, &slice)?;
+        let phantom_out = out.clone();
+
+        tape.try_alloc_grad(&inp)?;
+        tape.try_alloc_grad(&out)?;
+        tape.add_backward_op(move |grads| {
+            let (grad_inp, grad_out) = grads.mut_and_ref(&inp, &phantom_out);
+            inp.device.backward(&inp, grad_inp, grad_out, &slice)
+        });
+        Ok(out.put_tape(tape))
+    }
+
+    /// Calls [slice].
+    pub fn slice<Slice>(self, slice: Slice) -> Tensor<S::Sliced, E, D, T>
+    where
+        S: SliceShape<Slice>,
+        Slice: 'static,
+    {
+        self.try_slice(slice).unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tensor_ops::*;
+    use crate::tests::TestDevice;
+
+    #[test]
+    fn test_slice() {
+        let dev = TestDevice::default();
+        let a = dev.tensor([
+            [1., 2., 3., 4.],
+            [5., 6., 7., 8.],
+            [9., 10., 11., 12.],
+            [13., 14., 15., 16.],
+        ]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((2.., 2..)).realize().unwrap();
+        assert_eq!(b.array(), [[11., 12.], [15., 16.]]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 1..3)).realize().unwrap();
+        assert_eq!(b.array(), [[6., 7.], [10., 11.]]);
+
+        let b: Tensor<Rank2<1, 3>, _, _> = a.clone().slice((..1, 1..4)).realize().unwrap();
+        assert_eq!(b.array(), [[2., 3., 4.]]);
+
+        let b: Tensor<Rank2<2, 3>, _, _> = a.clone().slice((1..3, ..3)).realize().unwrap();
+        assert_eq!(b.array(), [[5., 6., 7.], [9., 10., 11.]]);
+
+        let b: Tensor<Rank2<2, 3>, _, _> = a.clone().slice((1..=2, 1..=3)).realize().unwrap();
+        assert_eq!(b.array(), [[6., 7., 8.], [10., 11., 12.]]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((0..=1, 2..=3)).realize().unwrap();
+        assert_eq!(b.array(), [[3., 4.], [7., 8.]]);
+
+        let b: Tensor<Rank2<3, 2>, _, _> = a.clone().slice((1.., ..2)).realize().unwrap();
+        assert_eq!(b.array(), [[5., 6.], [9., 10.], [13., 14.]]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((..2, 2..)).realize().unwrap();
+        assert_eq!(b.array(), [[3., 4.], [7., 8.]]);
+    }
+
+    #[test]
+    fn test_slice_broadcast_top() {
+        let dev = TestDevice::default();
+        let a: Tensor<Rank2<5, 4>, _, _> = dev.tensor([1., 2., 3., 4.]).broadcast();
+
+        let b: Tensor<Rank2<3, 4>, _, _> = a.clone().slice((..3, ..)).realize().unwrap();
+        assert_eq!(b.array(), [[1., 2., 3., 4.]; 3]);
+
+        let b: Tensor<Rank2<5, 2>, _, _> = a.clone().slice((.., 1..3)).realize().unwrap();
+        assert_eq!(b.array(), [[2., 3.]; 5]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 1..3)).realize().unwrap();
+        assert_eq!(b.array(), [[2., 3.], [2., 3.]]);
+
+        let b: Tensor<Rank2<3, 3>, _, _> = a.clone().slice((1..4, 1..4)).realize().unwrap();
+        assert_eq!(b.array(), [[2., 3., 4.]; 3]);
+    }
+
+    #[test]
+    fn test_slice_broadcast_bottom() {
+        let dev = TestDevice::default();
+        let a: Tensor<Rank2<4, 5>, _, _> = dev.tensor([1., 2., 3., 4.]).broadcast();
+
+        let b: Tensor<Rank2<2, 5>, _, _> = a.clone().slice((1..3, ..)).realize().unwrap();
+        assert_eq!(b.array(), [[2.; 5], [3.; 5]]);
+
+        let b: Tensor<Rank2<4, 2>, _, _> = a.clone().slice((.., 1..3)).realize().unwrap();
+        assert_eq!(b.array(), [[1., 1.], [2., 2.], [3., 3.], [4., 4.]]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((1..3, 3..)).realize().unwrap();
+        assert_eq!(b.array(), [[2., 2.], [3., 3.]]);
+
+        let b: Tensor<Rank2<2, 2>, _, _> = a.clone().slice((..2, 1..3)).realize().unwrap();
+        assert_eq!(b.array(), [[1., 1.], [2., 2.]]);
+    }
+
+    #[test]
+    fn test_slice_backward() {
+        let dev = TestDevice::default();
+        let a = dev.tensor([
+            [1., 2., 3., 4.],
+            [5., 6., 7., 8.],
+            [9., 10., 11., 12.],
+            [13., 14., 15., 16.],
+        ]);
+
+        let b: Tensor<Rank2<2, 2>, _, _, _> = a.leaky_trace().slice((2.., 2..)).realize().unwrap();
+        assert_eq!(b.array(), [[11., 12.], [15., 16.]]);
+        let g = b.square().sum().backward();
+        assert_eq!(
+            g.get(&a).array(),
+            [[0.; 4], [0.; 4], [0., 0., 22., 24.], [0., 0., 30., 32.]]
+        );
+    }
+}
diff --git a/src/tensor_ops/slice/slice.cu b/src/tensor_ops/slice/slice.cu
new file mode 100644
index 000000000..5a520cc25
--- /dev/null
+++ b/src/tensor_ops/slice/slice.cu
@@ -0,0 +1,83 @@
+#include <cstdint>
+#include "cuda_utils.cuh"
+
+template<typename T>
+__device__ void slice_fwd(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides,
+    const size_t offset,
+    const T *inp,
+    T *out
+) {
+    unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (out_i >= numel) {
+        return;
+    }
+
+    unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides);
+    out[out_i] = inp[inp_i];
+}
+
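+// As in slice_fwd above, `dims` describe the sliced output, `strides` are the
+// input tensor's strides, and `offset` is the linear index of the slice's
+// first element; get_strided_index maps each output index back to its source.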
+template<typename T>
+__device__ void slice_bwd(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides,
+    const size_t offset,
+    T *grad_inp,
+    const T *grad_out
+) {
+    unsigned int out_i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (out_i >= numel) {
+        return;
+    }
+
+    unsigned int inp_i = offset + get_strided_index(out_i, num_dims, dims, strides);
+    // TODO (maybe): use chunk_sum to speed this up
+    atomicAdd(grad_inp + inp_i, grad_out[out_i]);
+}
+
+#define SLICE_FWD(TYPENAME, FN) \
+extern "C" __global__ void FN( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *dims, \
+    const size_t *strides, \
+    const size_t offset, \
+    const TYPENAME *inp, \
+    TYPENAME *out \
+) { \
+    slice_fwd(numel, num_dims, dims, strides, offset, inp, out); \
+}
+
+#define SLICE(TYPENAME, FWD, BWD) \
+SLICE_FWD(TYPENAME, FWD) \
+\
+extern "C" __global__ void BWD( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *dims, \
+    const size_t *strides, \
+    const size_t offset, \
+    TYPENAME *grad_inp, \
+    const TYPENAME *grad_out \
+) { \
+    slice_bwd(numel, num_dims, dims, strides, offset, grad_inp, grad_out); \
+}
+
+SLICE(float, slice_fwd_f32, slice_bwd_f32);
+SLICE(double, slice_fwd_f64, slice_bwd_f64);
+SLICE_FWD(uint8_t, slice_fwd_u8);
+SLICE_FWD(uint16_t, slice_fwd_u16);
+SLICE_FWD(uint32_t, slice_fwd_u32);
+SLICE_FWD(uint64_t, slice_fwd_u64);
+SLICE_FWD(uintptr_t, slice_fwd_usize);
+SLICE_FWD(int8_t, slice_fwd_i8);
+SLICE_FWD(int16_t, slice_fwd_i16);
+SLICE_FWD(int32_t, slice_fwd_i32);
+SLICE_FWD(int64_t, slice_fwd_i64);
+SLICE_FWD(intptr_t, slice_fwd_isize);
+SLICE_FWD(bool, slice_fwd_bool);
diff --git a/src/tensor_ops/utilities/device.rs b/src/tensor_ops/utilities/device.rs
index d623159c2..6911cb5b3 100644
--- a/src/tensor_ops/utilities/device.rs
+++ b/src/tensor_ops/utilities/device.rs
@@ -35,6 +35,7 @@ pub trait Device<E: Dtype>:
     + super::super::select_and_gather::ReplaceDimKernel<E>
     + super::super::select_and_gather::RemoveDimKernel<E>
     + super::super::choose::ChooseKernel<E>
+    + super::super::slice::SliceKernel<E>
 
     // matmuls
     + super::super::matmul::VecMatKernel<E>