Clone now keeps same id, removing Tensor::duplicate (#249)
* Removing duplicate. Adding detach. Improving Clone semantics

* Removing detach
coreylowman authored Oct 17, 2022
1 parent d79ca59 commit caf15b7
Showing 20 changed files with 57 additions and 85 deletions.
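In practice the change is small but important: `Clone` now copies the tensor's unique id instead of generating a fresh one, so a clone refers to the same node as far as gradients are concerned, and the old `Tensor::duplicate()` (which existed mainly to preserve the id) is no longer needed. A minimal stand-in sketch of the new semantics, using toy types rather than the crate's real `Tensor1D`/`UniqueId`:

```rust
use std::sync::Arc;

// Toy stand-in for a dfdx tensor: a unique id plus Arc'd data.
struct ToyTensor {
    id: usize,           // stands in for dfdx's UniqueId
    data: Arc<[f32; 4]>, // Arc keeps clones cheap (see src/tensor/structs.rs below)
}

impl Clone for ToyTensor {
    fn clone(&self) -> Self {
        Self {
            id: self.id,             // previously the equivalent of `id: unique_id()`
            data: self.data.clone(), // bumps the Arc refcount; no new allocation
        }
    }
}

fn main() {
    let a = ToyTensor { id: 7, data: Arc::new([0.0; 4]) };
    let b = a.clone();
    assert_eq!(a.id, b.id); // mirrors the updated test in src/tensor/impl_tensor.rs
}
```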
2 changes: 1 addition & 1 deletion README.md
@@ -158,7 +158,7 @@ Since all operations result in exactly 1 child, we can always move the gradient
*All of this together gives users unprecedented control/precision over what tensors are recorded on the gradient tape!*

One advanced use case requires that tensors be re-used multiple times in a computation graph.
This can be handled by duplicating the tensor, and manually moving the gradient tape around.
This can be handled by cloning the tensor, and manually moving the gradient tape around.
See [examples/12-multi-headed.rs](examples/12-multi-headed.rs) for an example.

### Type checked backward
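The README paragraph above is the user-facing side of this change. A condensed sketch of the re-use pattern it describes, assuming the `traced`/`split_tape`/`put_tape`/`square`/`exp`/`add`/`mean` APIs behave as they are used elsewhere in this diff (the shapes and ops here are arbitrary):

```rust
use dfdx::prelude::*;

fn main() {
    let x: Tensor1D<5> = TensorCreator::zeros();
    let x = x.traced(); // x now carries an OwnedTape

    // Re-use `x` twice in one graph: clone it (the clone keeps the same id)
    // and move the single gradient tape between the two uses by hand.
    let (x, tape) = x.split_tape();
    let (y1, tape) = square(x.clone().put_tape(tape)).split_tape(); // 1st use
    let y2 = x.put_tape(tape).exp();                                // 2nd use

    let _loss = add(y2, &y1).mean(); // the tape behind `_loss` saw both uses
}
```

The same shape of code appears in examples/rl-ppo.rs and src/nn/residual.rs further down this diff.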
2 changes: 1 addition & 1 deletion examples/01-tensor.rs
@@ -24,7 +24,7 @@ fn main() {
let a_data: &[[[f32; 4]; 3]; 2] = a.data();
println!("a={:?}", a_data);

// you can clone() a tensor (or duplicate()):
// you can clone() a tensor:
let a_copy = a.clone();
assert_eq!(a_copy.data(), a.data());
}
6 changes: 3 additions & 3 deletions examples/rl-ppo.rs
@@ -48,9 +48,9 @@ fn main() {
let ratio = (log_prob_a - &old_log_prob_a).exp();

// because we need to re-use `ratio` a 2nd time, we need to do some tape manipulation here.
let r_ = ratio.duplicate();
let (surr1, tape) = (ratio * &advantage).split_tape();
let surr2 = (r_.put_tape(tape)).clamp(0.8, 1.2) * &advantage;
let (ratio, tape) = ratio.split_tape();
let (surr1, tape) = (ratio.clone().put_tape(tape) * &advantage).split_tape();
let surr2 = (ratio.put_tape(tape)).clamp(0.8, 1.2) * &advantage;

let ppo_loss = -(minimum(surr2, &surr1).mean());

3 changes: 1 addition & 2 deletions src/losses.rs
@@ -1,7 +1,6 @@
//! Standard loss functions such as [mse_loss()], [cross_entropy_with_logits_loss()], and more.
use crate::arrays::{AllAxes, HasArrayType, HasLastAxis};
use crate::tensor::Tensor;
use crate::tensor_ops::*;

/// [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error).
@@ -150,7 +149,7 @@ where
{
let probs = log_softmax::<_, <T::Array as HasLastAxis>::LastAxis>(logits);
let r = negate(mean::<_, AllAxes>(mul(
sub(probs, &ln(target_probs.duplicate())),
sub(probs, &ln(target_probs.clone())),
target_probs,
)));
mul_scalar(r, <T::Array as HasLastAxis>::SIZE as f32)
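For reference, with $p = \mathrm{softmax}(\text{logits})$ and $q$ = `target_probs`, the loss body above works out to a batch-mean KL divergence; this is just a transcription of the expression shown, assuming a batched input with $B$ rows:

$$\mathcal{L} \;=\; \frac{1}{B}\sum_{b=1}^{B}\sum_{i} q_{bi}\,\bigl(\ln q_{bi} - \ln p_{bi}\bigr)$$

The `mean` over `AllAxes` divides by $B \cdot \text{SIZE}$, and the trailing `mul_scalar` by `SIZE` restores the per-row sum.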
24 changes: 12 additions & 12 deletions src/nn/batchnorm2d.rs
@@ -61,14 +61,14 @@ impl<const C: usize> BatchNorm2D<C> {
Tensor1D<C>: BroadcastTo<T, Axes>,
{
// statistics for normalizing
let std = (self.running_var.duplicate() + self.epsilon).sqrt();
let mean = self.running_mean.duplicate();
let std = (self.running_var.clone() + self.epsilon).sqrt();
let mean = self.running_mean.clone();

// normalize & affine
let x = sub(x, &mean.broadcast());
let x = div(x, &std.broadcast());
let x = mul(x, &self.scale.duplicate().broadcast());
add(x, &self.bias.duplicate().broadcast())
let x = mul(x, &self.scale.clone().broadcast());
add(x, &self.bias.clone().broadcast())
}

fn train_fwd<T, Axes>(&mut self, x: T) -> T
@@ -80,19 +80,19 @@ impl<const C: usize> BatchNorm2D<C> {
let (x, tape) = x.split_tape();

// compute statistics for updating running stats later - on tape
let (mean_t, tape): (Tensor1D<C>, _) = mean(x.duplicate().put_tape(tape)).split_tape();
let (var_t, tape): (Tensor1D<C>, _) = var(x.duplicate().put_tape(tape)).split_tape();
let (mean_t, tape): (Tensor1D<C>, _) = mean(x.clone().put_tape(tape)).split_tape();
let (var_t, tape): (Tensor1D<C>, _) = var(x.clone().put_tape(tape)).split_tape();

// update statistics since we are training - off tape
self.running_mean = add(
self.running_mean.duplicate() * (1.0 - self.momentum),
&(mean_t.duplicate() * self.momentum),
self.running_mean.clone() * (1.0 - self.momentum),
&(mean_t.clone() * self.momentum),
);
let n = <T::Array as HasAxes<Axes>>::SIZE as f32;
self.running_var = add(
self.running_var.duplicate() * (1.0 - self.momentum),
self.running_var.clone() * (1.0 - self.momentum),
// NOTE: uses unbiased variance in running estimate
&(var_t.duplicate() * (self.momentum * n / (n - 1.0))),
&(var_t.clone() * (self.momentum * n / (n - 1.0))),
);

// statistics for normalizing - on tape
@@ -102,9 +102,9 @@ impl<const C: usize> BatchNorm2D<C> {
let (mean, tape) = mean.split_tape();

// record broadcast of scale & bias - on tape
let scale: T = self.scale.duplicate().put_tape(tape).broadcast();
let scale: T = self.scale.clone().put_tape(tape).broadcast();
let (scale, tape) = scale.split_tape();
let bias: T = self.bias.duplicate().put_tape(tape).broadcast();
let bias: T = self.bias.clone().put_tape(tape).broadcast();
let (bias, tape) = bias.split_tape();

// normalize & affine - on tape
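Written out, the running-statistics update in `train_fwd` above is, with $m$ = `self.momentum` and $n$ the number of reduced elements (the $n/(n-1)$ factor is the unbiased-variance correction noted in the code):

$$\texttt{running\_mean} \leftarrow (1-m)\,\texttt{running\_mean} + m\,\mu_{\text{batch}}, \qquad \texttt{running\_var} \leftarrow (1-m)\,\texttt{running\_var} + m\,\tfrac{n}{n-1}\,\sigma^2_{\text{batch}}$$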
4 changes: 2 additions & 2 deletions src/nn/generalized_residual.rs
@@ -54,7 +54,7 @@
let (x, tape) = x.split_tape();

// do R(x) on the tape
let r_x = self.r.forward(x.duplicate().put_tape(tape));
let r_x = self.r.forward(x.clone().put_tape(tape));
let (r_x, tape) = r_x.split_tape();

// do F(x) on the tape
@@ -77,7 +77,7 @@
let (x, tape) = x.split_tape();

// do R(x) on the tape
let r_x = self.r.forward_mut(x.duplicate().put_tape(tape));
let r_x = self.r.forward_mut(x.clone().put_tape(tape));
let (r_x, tape) = r_x.split_tape();

// do F(x) on the tape
8 changes: 4 additions & 4 deletions src/nn/layer_norm.rs
@@ -77,9 +77,9 @@ impl<H: Tape, const B: usize, const M: usize> Module<Tensor2D<B, M, H>> for Laye
/// 3. [add()] with [Self::beta]
fn forward(&self, x: Tensor2D<B, M, H>) -> Self::Output {
let (x, tape) = x.normalize::<Axis<1>>(self.epsilon).split_tape();
let g: Tensor2D<B, M, H> = self.gamma.duplicate().put_tape(tape).broadcast();
let g: Tensor2D<B, M, H> = self.gamma.clone().put_tape(tape).broadcast();
let (x, tape) = mul(g, &x).split_tape();
let b = self.beta.duplicate().put_tape(tape).broadcast();
let b = self.beta.clone().put_tape(tape).broadcast();
add(b, &x)
}
}
@@ -95,9 +95,9 @@ impl<H: Tape, const B: usize, const S: usize, const M: usize> Module<Tensor3D<B,
/// 3. [add()] with [Self::beta]
fn forward(&self, x: Tensor3D<B, S, M, H>) -> Self::Output {
let (x, tape) = x.normalize::<Axis<2>>(self.epsilon).split_tape();
let g: Tensor3D<B, S, M, H> = self.gamma.duplicate().put_tape(tape).broadcast();
let g: Tensor3D<B, S, M, H> = self.gamma.clone().put_tape(tape).broadcast();
let (x, tape) = mul(g, &x).split_tape();
let b = self.beta.duplicate().put_tape(tape).broadcast();
let b = self.beta.clone().put_tape(tape).broadcast();
add(b, &x)
}
}
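Both forward impls above compute the standard layer norm, with $\mu$ and $\sigma^2$ taken over the last axis (`Axis<1>` for the 2D case, `Axis<2>` for the 3D case):

$$y \;=\; \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$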
4 changes: 2 additions & 2 deletions src/nn/linear.rs
@@ -67,7 +67,7 @@ impl<const B: usize, const I: usize, const O: usize, H: Tape> Module<Tensor2D<B,
/// Batched 2d forward using [matmul()] and [add()]
fn forward(&self, x: Tensor2D<B, I, H>) -> Self::Output {
let (x, tape) = matmul_transpose(x, &self.weight).split_tape();
add(self.bias.duplicate().put_tape(tape).broadcast(), &x)
add(self.bias.clone().put_tape(tape).broadcast(), &x)
}
}

@@ -79,7 +79,7 @@ impl<const B: usize, const S: usize, const I: usize, const O: usize, H: Tape>
/// Batched 3d forward using [matmul()] and [add()]
fn forward(&self, x: Tensor3D<B, S, I, H>) -> Self::Output {
let (x, tape) = matmul_transpose(x, &self.weight).split_tape();
add(self.bias.duplicate().put_tape(tape).broadcast(), &x)
add(self.bias.clone().put_tape(tape).broadcast(), &x)
}
}

4 changes: 2 additions & 2 deletions src/nn/residual.rs
@@ -36,15 +36,15 @@ impl<T: Tensor<Dtype = f32>, F: Module<T, Output = T>> Module<T> for Residual<F>
type Output = F::Output;
fn forward(&self, x: T) -> Self::Output {
let (x, tape) = x.split_tape();
add(self.0.forward(x.duplicate().put_tape(tape)), &x)
add(self.0.forward(x.clone().put_tape(tape)), &x)
}
}

impl<T: Tensor<Dtype = f32>, F: ModuleMut<T, Output = T>> ModuleMut<T> for Residual<F> {
type Output = F::Output;
fn forward_mut(&mut self, x: T) -> Self::Output {
let (x, tape) = x.split_tape();
add(self.0.forward_mut(x.duplicate().put_tape(tape)), &x)
add(self.0.forward_mut(x.clone().put_tape(tape)), &x)
}
}

4 changes: 2 additions & 2 deletions src/nn/split_into.rs
@@ -51,7 +51,7 @@ where
fn forward(&self, x: Input) -> Self::Output {
let (x, tape) = x.split_tape();
let ($($heads, )+ $tail) = &self.0;
$(let ($heads, tape) = $heads.forward(x.duplicate().put_tape(tape)).split_tape();)+
$(let ($heads, tape) = $heads.forward(x.clone().put_tape(tape)).split_tape();)+
let $tail = $tail.forward(x.put_tape(tape));
(
$($heads,)+
@@ -77,7 +77,7 @@
fn forward_mut(&mut self, x: Input) -> Self::Output {
let (x, tape) = x.split_tape();
let ($($heads, )+ $tail) = &mut self.0;
$(let ($heads, tape) = $heads.forward_mut(x.duplicate().put_tape(tape)).split_tape();)+
$(let ($heads, tape) = $heads.forward_mut(x.clone().put_tape(tape)).split_tape();)+
let $tail = $tail.forward_mut(x.put_tape(tape));
(
$($heads,)+
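Since the `$heads`/`$tail` macro above is hard to read in isolation, here is a small usage sketch; the `SplitInto`/`Linear` types are the crate's, but the `Default` initialization and exact output types are assumptions:

```rust
use dfdx::prelude::*;

// Hypothetical two-headed model: each head forwards a clone of the same input
// (same id), and the macro threads a single tape through the heads in sequence.
type TwoHeaded = SplitInto<(Linear<4, 2>, Linear<4, 3>)>;

fn main() {
    let m: TwoHeaded = Default::default();
    let x: Tensor1D<4> = TensorCreator::zeros();
    let (a, b): (Tensor1D<2>, Tensor1D<3>) = m.forward(x);
    println!("head outputs: {:?} / {:?}", a.data(), b.data());
}
```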
19 changes: 9 additions & 10 deletions src/nn/transformer/decoder.rs
@@ -39,14 +39,14 @@ impl<const M: usize, const H: usize, const F: usize, const L: usize> CanUpdateWi
impl<const M: usize, const H: usize, const F: usize, const L: usize, Tgt, Mem> Module<(Tgt, Mem)>
for TransformerDecoder<M, H, F, L>
where
Mem: Tensor<NoTape = Mem>,
Mem: Tensor<NoTape = Mem> + Clone,
TransformerDecoderBlock<M, H, F>: Module<(Tgt, Mem), Output = Tgt>,
{
type Output = Tgt;

fn forward(&self, (mut x, mem): (Tgt, Mem)) -> Self::Output {
for block in self.0.modules.iter() {
x = block.forward((x, mem.duplicate()));
x = block.forward((x, mem.clone()));
}
x
}
@@ -125,7 +125,7 @@ impl<const M: usize, const H: usize, const F: usize, Tgt, Mem> Module<(Tgt, Mem)
for TransformerDecoderBlock<M, H, F>
where
Tgt: Tensor<Dtype = f32>,
Mem: Tensor<Dtype = f32, NoTape = Mem>,
Mem: Tensor<Dtype = f32, NoTape = Mem> + Clone,
MultiHeadAttention<M, H>: Module<(Tgt, Tgt::NoTape, Tgt::NoTape), Output = Tgt>
+ Module<(Tgt, Mem, Mem), Output = Tgt>,
LayerNorm1D<M>: Module<Tgt, Output = Tgt>,
@@ -135,16 +135,15 @@ where

fn forward(&self, (tgt, mem): (Tgt, Mem)) -> Self::Output {
let (tgt, tape) = tgt.split_tape();
let x = self.self_attn.forward((
tgt.duplicate().put_tape(tape),
tgt.duplicate(),
tgt.duplicate(),
));
let x = self
.self_attn
.forward((tgt.clone().put_tape(tape), tgt.clone(), tgt.clone()));
let x = add(x, &tgt);
let x = self.norm1.forward(x);

let x_ = x.duplicate();
let x = self.mh_attn.forward((x, mem.duplicate(), mem));
let (x, tape) = x.split_tape();
let x_ = x.clone();
let x = self.mh_attn.forward((x.put_tape(tape), mem.clone(), mem));
let x = add(x, &x_);
let x = self.norm2.forward(x);
let x = self.ff.forward(x);
8 changes: 3 additions & 5 deletions src/nn/transformer/encoder.rs
@@ -80,11 +80,9 @@

fn forward(&self, src: Src) -> Self::Output {
let (src, tape) = src.split_tape();
let x = self.self_attn.forward((
src.duplicate().put_tape(tape),
src.duplicate(),
src.duplicate(),
));
let x = self
.self_attn
.forward((src.clone().put_tape(tape), src.clone(), src.clone()));
let x = add(x, &src);
let x = self.norm1.forward(x);
let x = self.ff.forward(x);
26 changes: 4 additions & 22 deletions src/tensor/impl_tensor.rs
@@ -1,7 +1,7 @@
use crate::arrays::HasArrayType;
use crate::gradients::{CanUpdateWithGradients, NoneTape, Tape};
use crate::prelude::*;
use crate::unique_id::{unique_id, HasUniqueId};
use crate::unique_id::HasUniqueId;

/// The main tensor trait. A tensor consists of mainly 1. an array, 2. a device, 3. a unique id.
pub trait Tensor:
@@ -21,9 +21,6 @@ pub trait Tensor:

/// Removes whatever Tape this tensor has and returns itself without a tape.
fn split_tape(self) -> (Self::NoTape, Self::Tape);

/// Clones the data and id of this tensor and returns something with [NoneTape].
fn duplicate(&self) -> Self::NoTape;
}

macro_rules! tensor_impl {
@@ -38,21 +35,13 @@ impl<$(const $Vs: usize, )* H: Tape> Tensor for $struct<$($Vs, )* H> {
self.tape,
)
}

fn duplicate(&self) -> Self::NoTape {
Self::NoTape {
id: self.id,
data: self.data.clone(),
tape: Default::default(),
}
}
}

impl<$(const $Vs: usize, )* H: Clone> Clone for $struct<$($Vs, )* H> {
/// Clones the underlying data and tape. **Creates a new `id`.**
/// Clones the underlying id, data, and tape
fn clone(&self) -> Self {
Self {
id: unique_id(),
id: self.id,
data: self.data.clone(),
tape: self.tape.clone(),
}
@@ -71,18 +60,11 @@ tensor_impl!(Tensor4D, [M, N, O, P]);
mod tests {
use super::*;

#[test]
fn test_ids_with_duplicate() {
let t1: Tensor1D<32> = TensorCreator::zeros();
let t2: Tensor1D<32, NoneTape> = t1.duplicate();
assert_eq!(t1.id, t2.id);
}

#[test]
fn test_ids_with_clone() {
let t1: Tensor1D<32> = TensorCreator::zeros();
let t2: Tensor1D<32, NoneTape> = t1.clone();
assert_ne!(t1.id, t2.id);
assert_eq!(t1.id, t2.id);
}

#[test]
5 changes: 2 additions & 3 deletions src/tensor/impl_trace.rs
@@ -2,12 +2,11 @@ use super::*;
use crate::gradients::{NoneTape, OwnedTape};

/// Transforms a [NoneTape] tensor to an [OwnedTape] tensor by cloning.
/// Clones `t` using [Tensor::duplicate()] (to preserve id), and then
/// inserts [OwnedTape] as the tape.
/// Clones `t`, and then inserts [OwnedTape] as the tape.
///
/// See [traced()] for version that takes ownership of `t`.
pub fn trace<T: Tensor<Tape = OwnedTape>>(t: &T::NoTape) -> T {
traced(t.duplicate())
traced(t.clone())
}

/// Transforms a [NoneTape] tensor to an [OwnedTape] by directly inserting a
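With `duplicate()` gone, `trace()` is now just a `clone()` (same id) plus an `OwnedTape`. A quick sketch using only API that appears elsewhere in this diff:

```rust
use dfdx::prelude::*;

fn main() {
    let t: Tensor1D<5> = TensorCreator::zeros();
    // trace() clones `t`, keeping its id, and attaches an OwnedTape, so
    // everything recorded on the traced copy still refers back to `t`.
    let traced: Tensor1D<5, OwnedTape> = t.trace();
    assert_eq!(traced.data(), t.data());
}
```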
6 changes: 0 additions & 6 deletions src/tensor/mod.rs
@@ -69,12 +69,6 @@
//! let t_clone: Tensor1D<5, OwnedTape> = t.trace(); // copies t
//! let t: Tensor1D<5, OwnedTape> = t.traced(); // takes ownership of t
//! ```
//!
//! # Cloning/copying
//!
//! There are two primary methods for copying a tensor
//! 1. [Clone] is implemented for tensors without a tape. **NOTE** that the unique id is modified when a tensor is cloned
//! 2. [Tensor::duplicate()] is implemented for all tensors, it copies the [crate::unique_id::UniqueId], and returns a tensor with no tape.
mod impl_default;
mod impl_has_array;
2 changes: 1 addition & 1 deletion src/tensor/structs.rs
@@ -1,5 +1,5 @@
//! We use [std::sync::Arc] instead of [Box] here to:
//! 1. reduce allocations when tensors are duplicated/cloned.
//! 1. reduce allocations when tensors are cloned.
//! 2. make sharing tensors and things that contain tensors across threads easy
//!
//! See the following for more discussion:
4 changes: 2 additions & 2 deletions src/tensor_ops/impl_normalize.rs
@@ -19,10 +19,10 @@
T::Array: HasAxes<Axes>,
{
let (t, tape) = t.split_tape();
let (std, tape) = stddev(t.duplicate().put_tape(tape), epsilon)
let (std, tape) = stddev(t.clone().put_tape(tape), epsilon)
.broadcast()
.split_tape();
let (mean, tape) = mean(t.duplicate().put_tape(tape)).broadcast().split_tape();
let (mean, tape) = mean(t.clone().put_tape(tape)).broadcast().split_tape();
let centered = sub(t.put_tape(tape), &mean);
div(centered, &std)
}
4 changes: 1 addition & 3 deletions src/tensor_ops/impl_softmax.rs
@@ -51,9 +51,7 @@ pub fn logsumexp<T: Reduce<Axes>, Axes>(mut t: T) -> T::Reduced {
/// ```
pub fn log_softmax<T: Reduce<Axes>, Axes>(t: T) -> T {
let (t, tape) = t.split_tape();
let (lse, tape) = logsumexp(t.duplicate().put_tape(tape))
.broadcast()
.split_tape();
let (lse, tape) = logsumexp(t.clone().put_tape(tape)).broadcast().split_tape();
sub(t.put_tape(tape), &lse)
}

2 changes: 1 addition & 1 deletion src/tensor_ops/impl_stddev.rs
@@ -43,7 +43,7 @@
{
let num_elements: f32 = <T::Array as HasAxes<Axes>>::SIZE as f32;
let (t, tape) = t.split_tape();
let mean = mean(t.duplicate().put_tape(tape)).broadcast();
let mean = mean(t.clone().put_tape(tape)).broadcast();
div_scalar(sum(square(sub(mean, &t))), num_elements)
}
