Remove phantom #282

Merged 2 commits on Nov 2, 2022
43 changes: 0 additions & 43 deletions src/tensor/impl_phantom.rs

This file was deleted.

8 changes: 1 addition & 7 deletions src/tensor/impl_tensor.rs
@@ -5,13 +5,7 @@ use crate::unique_id::{internal, unique_id, HasUniqueId};
 
 /// The main tensor trait. A tensor consists of mainly 1. an array, 2. a device, 3. a unique id.
 pub trait Tensor:
-    HasArrayType
-    + HasArrayData
-    + HasDevice
-    + CanUpdateWithGradients
-    + HasUniqueId
-    + IntoPhantom
-    + internal::ResetId
+    HasArrayType + HasArrayData + HasDevice + CanUpdateWithGradients + HasUniqueId + internal::ResetId
 {
     /// The [Tape] this tensor owns.
     type Tape: Tape;
2 changes: 0 additions & 2 deletions src/tensor/mod.rs
@@ -74,7 +74,6 @@ mod impl_default;
 mod impl_has_array;
 mod impl_has_device;
 mod impl_has_unique_id;
-mod impl_phantom;
 mod impl_put_tape;
 mod impl_randomize;
 mod impl_tensor;
@@ -88,7 +87,6 @@ pub use impl_default::*;
 pub use impl_has_array::*;
 pub use impl_has_device::*;
 pub use impl_has_unique_id::*;
-pub use impl_phantom::*;
 pub use impl_put_tape::*;
 pub use impl_randomize::*;
 pub use impl_tensor::*;
12 changes: 6 additions & 6 deletions src/tensor_ops/conv.rs
@@ -21,9 +21,9 @@ impl<const C: usize, const H: usize, const W: usize, T: Tape> Tensor3D<C, H, W,
 
         let f = filters.clone();
         let (x, mut tape) = self.split_tape();
-        let phf = filters.phantom();
-        let phb = bias.phantom();
-        let phr = result.phantom();
+        let phf = filters.clone();
+        let phb = bias.clone();
+        let phr = result.clone();
         tape.add_backward_op(move |grads| {
             let (fg, bg, ig, rg) = grads.muts_and_ref(&phf, &phb, &x, &phr);
             <Cpu as DeviceConv2D<S, P>>::conv_backward(x.data(), f.data(), rg, ig, fg, bg);
@@ -51,9 +51,9 @@ impl<const B: usize, const C: usize, const H: usize, const W: usize, T: Tape>
         let f = filters.clone();
 
         let (x, mut tape) = self.split_tape();
-        let phf = filters.phantom();
-        let phb = bias.phantom();
-        let phr = result.phantom();
+        let phf = filters.clone();
+        let phb = bias.clone();
+        let phr = result.clone();
         tape.add_backward_op(move |grads| {
             let (fg, bg, ig, r_grad) = grads.muts_and_ref(&phf, &phb, &x, &phr);
             let f = f.data();
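Every hunk in this PR follows the pattern visible above, and again in the pool2d and utils changes below: wherever a backward op used to capture a PhantomTensor marker via .phantom(), it now captures a plain .clone() of the tensor. The following is a minimal, self-contained sketch of why a clone can stand in for the phantom marker. The Tensor, Gradients, and Tape types here are simplified stand-ins, not dfdx's real API; the assumptions are that tensor data sits behind an Rc (so clone() only bumps a reference count) and that gradients are looked up by a unique id that the clone shares with the original.

use std::collections::HashMap;
use std::rc::Rc;

// Simplified stand-in: data is shared behind an Rc, so clone() is cheap
// and the clone keeps the same unique id as the original tensor.
#[derive(Clone)]
struct Tensor {
    id: usize,
    data: Rc<Vec<f32>>,
}

// Gradients are keyed by unique id, mirroring grads.mut_and_ref(&x, &r).
#[derive(Default)]
struct Gradients {
    by_id: HashMap<usize, Vec<f32>>,
}

impl Gradients {
    // Fetch (or lazily allocate) the gradient slot for a tensor's id.
    fn grad_mut(&mut self, t: &Tensor) -> &mut Vec<f32> {
        self.by_id
            .entry(t.id)
            .or_insert_with(|| vec![0.0; t.data.len()])
    }
}

// A tape is just a stack of boxed backward ops.
struct Tape {
    ops: Vec<Box<dyn FnMut(&mut Gradients)>>,
}

fn main() {
    let result = Tensor { id: 7, data: Rc::new(vec![1.0, 2.0]) };

    // Before this PR the op captured result.phantom(); now it captures a
    // clone, which carries the same id and shares the same storage.
    let phr = result.clone();

    let mut tape = Tape { ops: Vec::new() };
    tape.ops.push(Box::new(move |grads: &mut Gradients| {
        let rg = grads.grad_mut(&phr);
        for g in rg.iter_mut() {
            *g += 1.0; // stand-in for a real backward rule
        }
    }));

    let mut grads = Gradients::default();
    for op in tape.ops.iter_mut() {
        op(&mut grads);
    }
    assert_eq!(grads.by_id[&7], vec![1.0, 1.0]);
}

The clone's shared unique id is all the backward op needs to find the right gradient slot, which is why the dedicated phantom type (and the IntoPhantom bound removed above) becomes redundant.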
2 changes: 1 addition & 1 deletion src/tensor_ops/pool2d.rs
@@ -76,7 +76,7 @@ impl<const B: usize, const C: usize, const H: usize, const W: usize, T: Tape>
             Cpu::pool_forward(x_i, r_i);
         }
         let (x, mut tape) = self.split_tape();
-        let r = result.phantom();
+        let r = result.clone();
         tape.add_backward_op(move |grads| {
             let (xg, rg) = grads.mut_and_ref(&x, &r);
             for ((x_i, rg_i), xg_i) in x.data().iter().zip(rg.iter()).zip(xg.iter_mut()) {
10 changes: 5 additions & 5 deletions src/tensor_ops/utils.rs
@@ -41,7 +41,7 @@ where
     let (t, mut tape) = t.split_tape();
     let mut result = t.clone(); // inc t's reference count
     result.reset_id(); // ensure there are two different nodes in the graph
-    let phantom_result = result.phantom();
+    let phantom_result = result.clone();
     tape.add_backward_op(move |grads| {
         let (t_grad, result_grad) = grads.mut_and_ref(&t, &phantom_result);
         T::Device::foreach_mrr(t_grad, t.data(), result_grad, &mut |g, fx, r| {
@@ -112,9 +112,9 @@ pub(super) fn move_tape_and_add_backward_op<Inp, Out, F>(
 where
     Inp: Tensor,
     Out: Tensor<Tape = Inp::Tape>,
-    F: 'static + FnMut(Inp::NoTape, PhantomTensor<Out::NoTape>, &mut Gradients),
+    F: 'static + FnMut(Inp::NoTape, Out::NoTape, &mut Gradients),
 {
-    let phantom_out = out.phantom();
+    let phantom_out = out.clone();
     let (t, mut tape) = inp.split_tape();
     tape.add_backward_op(move |grads| f(t, phantom_out, grads));
     out.put_tape(tape)
@@ -132,9 +132,9 @@ where
     Rhs: Tensor,
     Out: Tensor<Tape = Lhs::Tape>,
     Lhs::Tape: Merge<Rhs::Tape>,
-    F: 'static + FnMut(Lhs::NoTape, Rhs::NoTape, PhantomTensor<Out::NoTape>, &mut Gradients),
+    F: 'static + FnMut(Lhs::NoTape, Rhs::NoTape, Out::NoTape, &mut Gradients),
 {
-    let phantom_out = out.phantom();
+    let phantom_out = out.clone();
     let (lhs, lhs_tape) = lhs.split_tape();
     let (rhs, rhs_tape) = rhs.split_tape();
     let mut tape = lhs_tape.merge(rhs_tape);
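The two helper hunks above show the only API-visible consequence of the change: backward closures now receive the cloned output tensor (Out::NoTape) instead of a PhantomTensor<Out::NoTape> wrapper. Below is a hedged sketch of the single-input helper's new shape, reusing the toy Tensor, Gradients, and Tape types from the earlier sketch; these are not dfdx's actual signatures.

// Sketch of the helper's new shape after this PR, using the toy types
// from the previous example. The closure now takes the output tensor
// directly; out.clone() replaces the old out.phantom().
fn move_tape_and_add_backward_op<F>(inp: Tensor, out: Tensor, tape: &mut Tape, mut f: F)
where
    F: 'static + FnMut(Tensor, Tensor, &mut Gradients),
{
    let phantom_out = out.clone(); // was: out.phantom()
    tape.ops.push(Box::new(move |grads: &mut Gradients| {
        // Clones are cheap (an Rc bump) and keep the ids the backward
        // pass needs to find the right gradient slots.
        f(inp.clone(), phantom_out.clone(), grads)
    }));
}

The binary variant is the same except that it merges the two input tapes before pushing the op, matching the lhs_tape.merge(rhs_tape) line above.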