feat: use Cow when appropriate #688

Merged · 5 commits · Apr 10, 2023
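This PR replaces the hand-rolled `Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>` convention in kernel signatures (where `Ok` meant a borrowed input and `Err` an owned one) with `std::borrow::Cow<Tensor<S, E, Self>>`, which expresses the same borrowed-or-owned choice in standard-library vocabulary: a `Cow::Borrowed` input must be left intact, while a `Cow::Owned` input's buffer can be recycled for the output. Below is a minimal standalone sketch of the pattern; `Buffer` and `map_elementwise` are illustrative stand-ins, not names from this crate:

```rust
use std::borrow::Cow;

// Hypothetical stand-in for a tensor's data buffer.
#[derive(Clone, Debug, PartialEq)]
struct Buffer(Vec<f32>);

// Apply `f` elementwise. A borrowed input forces a fresh allocation;
// an owned input lets us overwrite its storage in place.
fn map_elementwise(inp: Cow<'_, Buffer>, f: impl Fn(f32) -> f32) -> Buffer {
    match inp {
        Cow::Borrowed(b) => Buffer(b.0.iter().copied().map(&f).collect()),
        Cow::Owned(mut b) => {
            for x in &mut b.0 {
                *x = f(*x);
            }
            b
        }
    }
}

fn main() {
    let b = Buffer(vec![1.0, 2.0, 3.0]);
    // Caller still needs `b` afterwards: lend it.
    let doubled = map_elementwise(Cow::Borrowed(&b), |x| x * 2.0);
    // Caller is done with `b`: hand it over so its buffer is reused.
    let squared = map_elementwise(Cow::Owned(b), |x| x * x);
    assert_eq!(doubled, Buffer(vec![2.0, 4.0, 6.0]));
    assert_eq!(squared, Buffer(vec![1.0, 4.0, 9.0]));
}
```

Compared with the `Result` encoding, `Cow` makes the intent self-documenting at every call site and matches how readers already understand borrowed-versus-owned data in Rust.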
src/tensor_ops/pow/cuda_kernel.rs (3 changes: 2 additions & 1 deletion)
@@ -4,6 +4,7 @@ use crate::{
tensor::*,
tensor_ops::{cuda_kernels::cuda_unary, ops::UnaryKernel},
};
+use std::borrow::Cow;

unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<f32> {}
unsafe impl cudarc::driver::DeviceRepr for super::PowfKernelOp<f64> {}
@@ -24,7 +25,7 @@ where
fn forward<S: Shape>(
&self,
op: super::PowiKernelOp,
-        inp: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        inp: Cow<Tensor<S, E, Self>>,
> Owner: This is awesome, nice 🔥

) -> Result<Tensor<S, E, Self>, Self::Err> {
self.forward(super::PowfKernelOp(E::from_i32(op.0).unwrap()), inp)
}
src/tensor_ops/utilities/cpu_kernels.rs (23 changes: 15 additions & 8 deletions)
@@ -1,3 +1,5 @@
+use std::borrow::Cow;
+
use super::ops::{BinaryKernel, UnaryKernel};
use crate::{
shapes::{Dtype, Shape},
@@ -48,10 +50,10 @@ impl<E: Dtype, Op: UnaryDerivative<E>> UnaryKernel<Op, E> for Cpu {
fn forward<S: Shape>(
&self,
op: Op,
-        inp: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        inp: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err> {
let mut out = match inp {
-            Ok(inp) => {
+            Cow::Borrowed(inp) => {
// allocate a new data buffer
Tensor {
id: unique_id(),
@@ -62,7 +64,7 @@ impl<E: Dtype, Op: UnaryDerivative<E>> UnaryKernel<Op, E> for Cpu {
tape: Default::default(),
}
}
-            Err(mut inp) => {
+            Cow::Owned(mut inp) => {
// re-use the data buffer
inp.id = unique_id();
inp
@@ -111,11 +113,11 @@ impl<E: Dtype, Op: BinaryDerivative<E>> BinaryKernel<Op, E> for Cpu {
fn forward<S: Shape>(
&self,
op: Op,
-        lhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
-        rhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        lhs: Cow<Tensor<S, E, Self>>,
+        rhs: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err> {
match (lhs, rhs) {
-            (Ok(lhs), Ok(rhs)) => {
+            (Cow::Borrowed(lhs), Cow::Borrowed(rhs)) => {
let mut out = self.try_zeros_like(&lhs.shape)?;
let mut lhs_iter = lhs.iter();
let mut rhs_iter = rhs.iter();
@@ -126,7 +128,7 @@ impl<E: Dtype, Op: BinaryDerivative<E>> BinaryKernel<Op, E> for Cpu {
}
Ok(out)
}
-            (Err(mut lhs), Err(mut rhs)) => {
+            (Cow::Owned(mut lhs), Cow::Owned(mut rhs)) => {
let lhs_valid = lhs.strides == lhs.shape.strides();
let rhs_valid = rhs.strides == rhs.shape.strides();
if lhs_valid || rhs_valid {
@@ -148,7 +150,12 @@
Ok(lhs)
}
} else {
-                    <Self as BinaryKernel<Op, E>>::forward(self, op, Ok(&lhs), Ok(&rhs))
+                    <Self as BinaryKernel<Op, E>>::forward(
+                        self,
+                        op,
+                        Cow::Borrowed(&lhs),
+                        Cow::Borrowed(&rhs),
+                    )
}
}
_ => unreachable!(),
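In the owned-owned arm above, the CPU binary kernel reuses whichever operand's buffer already has standard (dense row-major) strides, and only when neither operand is contiguous does it fall back to the borrowed path, which allocates a fresh output. A small sketch of that contiguity test, under the assumption that `shape.strides()` computes dense row-major strides (helper names here are illustrative, not from the crate):

```rust
/// Dense row-major strides for a shape, e.g. [2, 3] -> [3, 1].
fn dense_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// A buffer can be overwritten with the elementwise result only if its
/// layout is exactly the dense layout of its shape (i.e. it is not a
/// broadcast or permuted view sharing storage with another tensor).
fn can_reuse_buffer(shape: &[usize], strides: &[usize]) -> bool {
    strides == dense_strides(shape).as_slice()
}
```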
src/tensor_ops/utilities/cuda_kernels.rs (28 changes: 14 additions & 14 deletions)
@@ -4,7 +4,7 @@ use crate::{
tensor_ops::ops::{BinaryKernel, UnaryKernel},
};
use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync};
-use std::{sync::Arc, vec::Vec};
+use std::{borrow::Cow, sync::Arc, vec::Vec};

pub trait UnaryOpCudaKernel<E> {
const DF_USES_FX: bool;
@@ -66,7 +66,7 @@ impl<E: Dtype, K: UnaryOpCudaKernel<E> + DeviceRepr> UnaryKernel<K, E> for Cuda
fn forward<S: Shape>(
&self,
op: K,
-        inp: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        inp: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err> {
if !self.dev.has_func(K::MODULE_NAME, K::FWD_FN_NAME) {
self.dev
@@ -76,7 +76,7 @@ impl<E: Dtype, K: UnaryOpCudaKernel<E> + DeviceRepr> UnaryKernel<K, E> for Cuda
let fwd_fn = self.dev.get_func(K::MODULE_NAME, K::FWD_FN_NAME).unwrap();

match inp {
-            Ok(inp) => {
+            Cow::Borrowed(inp) => {
let numel = inp.data.len();
let mut storage = unsafe { self.dev.alloc::<E>(numel) }?;

@@ -93,7 +93,7 @@ impl<E: Dtype, K: UnaryOpCudaKernel<E> + DeviceRepr> UnaryKernel<K, E> for Cuda
tape: Default::default(),
})
}
-            Err(mut inp) => {
+            Cow::Owned(mut inp) => {
inp.id = unique_id();
let numel = inp.data.len();
let cfg = launch_cfg(numel as u32);
@@ -225,8 +225,8 @@ impl<E: Dtype, K: BinaryOpCudaKernel<E> + DeviceRepr + Clone> BinaryKernel<K, E>
fn forward<S: Shape>(
&self,
op: K,
-        lhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
-        rhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        lhs: Cow<Tensor<S, E, Self>>,
+        rhs: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err> {
if !self.dev.has_func(K::MODULE_NAME, K::FWD_FN_NAME) {
self.dev
@@ -235,20 +235,20 @@
let fwd_fn = self.dev.get_func(K::MODULE_NAME, K::FWD_FN_NAME).unwrap();

let shape = match &lhs {
-            Ok(lhs) => lhs.shape,
-            Err(lhs) => lhs.shape,
+            Cow::Borrowed(lhs) => lhs.shape,
+            Cow::Owned(lhs) => lhs.shape,
};
let strides = shape.strides();
let numel = shape.num_elements();
let cfg = launch_cfg(numel as u32);

let lhs_strides = match &lhs {
-            Ok(lhs) => lhs.strides,
-            Err(lhs) => lhs.strides,
+            Cow::Borrowed(lhs) => lhs.strides,
+            Cow::Owned(lhs) => lhs.strides,
};
let rhs_strides = match &rhs {
-            Ok(rhs) => rhs.strides,
-            Err(rhs) => rhs.strides,
+            Cow::Borrowed(rhs) => rhs.strides,
+            Cow::Owned(rhs) => rhs.strides,
};

let mut info: Vec<usize> = Vec::with_capacity(3 * S::NUM_DIMS);
@@ -258,7 +258,7 @@
let info = self.dev.htod_copy(info)?;

match (lhs, rhs) {
-            (Ok(lhs), Ok(rhs)) => {
+            (Cow::Borrowed(lhs), Cow::Borrowed(rhs)) => {
let mut storage = unsafe { self.dev.alloc::<E>(numel) }?;
let params = (
op,
@@ -279,7 +279,7 @@
tape: Default::default(),
})
}
-            (Err(mut lhs), Err(mut rhs)) => {
+            (Cow::Owned(mut lhs), Cow::Owned(mut rhs)) => {
let lhs_valid = lhs.strides == lhs.shape.strides();
let rhs_valid = rhs.strides == rhs.shape.strides();
if lhs_valid || rhs_valid {
src/tensor_ops/utilities/ops.rs (21 changes: 13 additions & 8 deletions)
@@ -2,14 +2,15 @@ use crate::{
shapes::{Dtype, HasShape, Shape},
tensor::{DeviceStorage, GhostTensor, Merge, PutTape, SplitTape, Tape, Tensor},
};
+use std::borrow::Cow;

pub trait UnaryKernel<Op, E: Dtype>: DeviceStorage {
const BACKWARD_WITHOUT_INP: bool;
const BACKWARD_WITHOUT_DATA: bool;
fn forward<S: Shape>(
&self,
op: Op,
-        inp: Cow<Tensor<S, E, Self>>,
+        inp: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err>;
fn backward<S: Shape>(
&self,
@@ -26,8 +27,8 @@ pub trait BinaryKernel<Op, E: Dtype>: DeviceStorage {
fn forward<S: Shape>(
&self,
op: Op,
-        lhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
-        rhs: Result<&Tensor<S, E, Self>, Tensor<S, E, Self>>,
+        lhs: Cow<Tensor<S, E, Self>>,
+        rhs: Cow<Tensor<S, E, Self>>,
) -> Result<Tensor<S, E, Self>, Self::Err>;
fn backward<S: Shape>(
&self,
@@ -54,7 +55,7 @@ pub(crate) fn try_unary_op<
let inp_ghost = inp.ghost();
let dev = inp.device.clone();
if !T::OWNS_TAPE || D::BACKWARD_WITHOUT_DATA {
-        let out = inp_ghost.dev.forward(op.clone(), Err(inp))?;
+        let out = inp_ghost.dev.forward(op.clone(), Cow::Owned(inp))?;
let out_ghost = out.ghost();
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&inp_ghost)?;
@@ -64,7 +65,7 @@
});
Ok(out.put_tape(tape))
} else if D::BACKWARD_WITHOUT_INP {
-        let out = inp_ghost.dev.forward(op.clone(), Err(inp))?;
+        let out = inp_ghost.dev.forward(op.clone(), Cow::Owned(inp))?;
let out_ghost = out.ghost();
let out_clone = out.clone();
tape.add_backward_op(move |grads| {
@@ -75,7 +76,7 @@
});
Ok(out.put_tape(tape))
} else {
-        let out = inp.device.forward(op.clone(), Ok(&inp))?;
+        let out = inp.device.forward(op.clone(), Cow::Borrowed(&inp))?;
let out_ghost = out.ghost();
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&inp_ghost)?;
@@ -106,7 +107,9 @@ pub(crate) fn try_binary_op<
let rhs_ghost = rhs.ghost();
let mut tape = ltape.merge(rtape);
if !LhsTape::OWNS_TAPE || D::BACKWARD_WITHOUT_DATA {
-        let out = lhs_ghost.dev.forward(op, Err(lhs), Err(rhs))?;
+        let out = lhs_ghost
+            .dev
+            .forward(op, Cow::Owned(lhs), Cow::Owned(rhs))?;
let out_ghost = out.ghost();
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&lhs_ghost)?;
@@ -125,7 +128,9 @@
});
Ok(out.put_tape(tape))
} else {
-        let out = lhs.device.forward(op, Ok(&lhs), Ok(&rhs))?;
+        let out = lhs
+            .device
+            .forward(op, Cow::Borrowed(&lhs), Cow::Borrowed(&rhs))?;
let out_ghost = out.ghost();
tape.add_backward_op(move |grads| {
grads.try_alloc_for(&lhs_ghost)?;
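A note on how the two helpers above choose the `Cow` variant: the input is handed to the kernel as `Cow::Owned` (so its buffer may be recycled) exactly when the backward pass will not need it again, i.e. when the tape is inactive or the kernel declares it can run backward without the input; otherwise it is lent as `Cow::Borrowed` and kept alive for the backward op. Condensed as a sketch (a hypothetical free function, not code from this crate):

```rust
/// true  -> pass Cow::Owned: the kernel may consume and recycle the input buffer.
/// false -> pass Cow::Borrowed: the input must survive for the backward pass.
fn kernel_may_take_ownership(
    owns_tape: bool,
    backward_without_inp: bool,
    backward_without_data: bool,
) -> bool {
    !owns_tape || backward_without_data || backward_without_inp
}
```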