Remove allocation in dropout implementation (#164)
Using map in dropout implementation
coreylowman authored Aug 22, 2022
1 parent afb475d commit 32f534b
Showing 2 changed files with 44 additions and 30 deletions.
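The diff below replaces the allocated `deriv` mask tensor with two `StdRng` streams built from a single sampled seed, so the forward and backward closures can regenerate the same mask on the fly instead of keeping a mask buffer alive. As a standalone illustration of that trick (a minimal sketch over plain slices with rand 0.8, not the crate's `map`/`Tensor` machinery):

```rust
use rand::{distributions::Standard, rngs::StdRng, Rng, SeedableRng};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    let (p, x) = (0.5f32, [1.0f32, 2.0, 3.0, 4.0]);

    // Draw one u64 and build two identical RNG streams from it.
    let seed: u64 = rng.gen();
    let mut fwd_rng = StdRng::seed_from_u64(seed);
    let mut bwd_rng = StdRng::seed_from_u64(seed);

    // Forward: zero with probability p, otherwise scale by 1 / (1 - p).
    let y: Vec<f32> = x
        .iter()
        .map(|&v| {
            let u: f32 = fwd_rng.sample(Standard);
            if u < p { 0.0 } else { v / (1.0 - p) }
        })
        .collect();

    // Backward: the per-element derivative is rebuilt from the second stream,
    // so no mask buffer from the forward pass has to be stored.
    let dydx: Vec<f32> = x
        .iter()
        .map(|_| {
            let u: f32 = bwd_rng.sample(Standard);
            if u < p { 0.0 } else { 1.0 / (1.0 - p) }
        })
        .collect();

    // The two streams agree, so an element is zeroed in `y` exactly when
    // its derivative is zero.
    assert!(y.iter().zip(&dydx).all(|(a, b)| (*a == 0.0) == (*b == 0.0)));
    println!("y = {y:?}, dy/dx = {dydx:?}");
}
```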
src/nn/dropout.rs (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@ use std::{cell::RefCell, ops::DerefMut};
/// let dropout: DropoutOneIn<2> = Default::default();
/// let t: Tensor2D<2, 5> = Tensor2D::ones();
/// let r = dropout.forward(t.trace());
-/// assert_eq!(r.data(), &[[2.0, 2.0, 2.0, 2.0, 2.0], [0.0, 2.0, 2.0, 2.0, 0.0]]);
+/// assert_eq!(r.data(), &[[2.0, 2.0, 2.0, 0.0, 0.0], [2.0, 2.0, 0.0, 0.0, 2.0]]);
/// ```
#[derive(Clone, Debug)]
pub struct DropoutOneIn<const N: usize> {
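The only change in this file is the doctest output, which shifts because the refactored dropout consumes the RNG differently (it now draws a seed and re-samples through `StdRng`). As the name suggests, `DropoutOneIn<N>` drops each element with probability 1/N, so with N = 2 the surviving entries of a ones tensor become 1 / (1 - 1/2) = 2.0, which is why every non-zero value in the assertion is 2.0. A tiny standalone check of that scale factor (my own sketch, not crate code):

```rust
fn main() {
    const N: usize = 2;
    let p = 1.0f32 / N as f32;
    // Kept entries of a tensor of ones are rescaled to 1 / (1 - p).
    assert_eq!(1.0 / (1.0 - p), 2.0);
    println!("DropoutOneIn<{N}> scales kept elements by {}", 1.0 / (1.0 - p));
}
```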
src/tensor_ops/impl_dropout.rs (72 changes: 43 additions & 29 deletions)
@@ -1,7 +1,6 @@
-use super::utils::move_tape_and_add_backward_op;
use crate::prelude::*;
-use rand::Rng;
-use rand_distr::{Distribution, Standard};
+use rand::{rngs::StdRng, Rng, SeedableRng};
+use rand_distr::Standard;

/// Does nothing if no tape is in `t`. Zeros elements with probability `p` and scales all elements by `1 / (1 - p)`.
/// See [Tape::OWNS_TAPE].
@@ -23,24 +22,41 @@ use rand_distr::{Distribution, Standard};
/// let a = dropout(t.trace(), 0.5, &mut rng);
/// assert_eq!(a.data(), &[2.0, 4.0, 0.0, 8.0]);
/// ```
+///
+/// ### Implementation details:
+///
+/// To reduce memory usage, this function first samples a u64 seed from `rng`,
+/// and then instantiates two identical [StdRng] with that seed. These rngs
+/// are used in both the forward pass and backward pass to generate identical
+/// random numbers, so the masking is the same for both.
pub fn dropout<T: Tensor<Dtype = f32>, R: Rng>(t: T, p: f32, rng: &mut R) -> T {
    if !T::Tape::OWNS_TAPE {
        // This is the branch where `t` doesn't own the tape, so we don't have to drop out anything.
        t
    } else {
        // `t` owns the tape in this branch, so apply dropout randomly.
-        let rinvp = (1.0 - p).recip();
-        let deriv = T::Device::filled(&mut |d| {
-            let val: f32 = Standard.sample(rng);
-            *d = if val < p { 0.0 } else { rinvp };
-        });
-        let mut result = T::NoTape::zeros();
-        T::Device::addmul(result.mut_data(), t.data(), deriv.as_ref());
-
-        move_tape_and_add_backward_op(t, result, move |t, result, grads| {
-            let (t_grad, result_grad) = grads.mut_and_ref(&t, &result);
-            T::Device::addmul(t_grad, deriv.as_ref(), result_grad);
-        })
+        let seed: u64 = rng.gen();
+        let mut fwd_rng = StdRng::seed_from_u64(seed);
+        let mut bwd_rng = StdRng::seed_from_u64(seed);
+        map(
+            t,
+            move |x| {
+                let val: f32 = fwd_rng.sample(Standard);
+                if val < p {
+                    0.0
+                } else {
+                    x / (1.0 - p)
+                }
+            },
+            move |_| {
+                let val: f32 = bwd_rng.sample(Standard);
+                if val < p {
+                    0.0
+                } else {
+                    1.0 / (1.0 - p)
+                }
+            },
+        )
    }
}

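Both the forward and backward closures rescale kept elements by 1 / (1 - p), which is what keeps inverted dropout unbiased in expectation. A quick standalone Monte Carlo check of that property (my own sketch with rand 0.8, not part of the diff):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};

// Check that inverted dropout is unbiased: E[dropout(x)] == x.
fn main() {
    let mut rng = StdRng::seed_from_u64(42);
    let (x, p, trials) = (3.0f64, 0.25f64, 1_000_000u32);

    let sum: f64 = (0..trials)
        .map(|_| {
            let u: f64 = rng.gen();
            if u < p { 0.0 } else { x / (1.0 - p) }
        })
        .sum();

    let mean = sum / trials as f64;
    // The sample mean should land near x = 3.0 up to Monte Carlo noise.
    assert!((mean - x).abs() < 0.05);
    println!("E[dropout({x})] ~ {mean:.4}");
}
```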
@@ -64,6 +80,7 @@ tensor_impl!(Tensor4D, [M, N, O, P]);
#[cfg(test)]
mod tests {
use super::*;
+use crate::tests::assert_close;
use rand::{prelude::StdRng, SeedableRng};

#[test]
@@ -99,25 +116,22 @@ mod tests {
let mut rng = StdRng::seed_from_u64(3);
let t = Tensor1D::new([0.0, 2.0, -3.0, -4.0, 0.0]);
let r = t.trace().dropout(0.5, &mut rng);
-assert_eq!(r.data(), &[0.0, 0.0, -6.0, 0.0, 0.0]);
+assert_eq!(r.data(), &[0.0, 0.0, 0.0, -8.0, 0.0]);
let gradients = r.mean().backward();
-assert_eq!(gradients.ref_gradient(&t), &[0.4, 0.0, 0.4, 0.0, 0.0]);
+assert_eq!(gradients.ref_gradient(&t), &[0.0, 0.0, 0.0, 0.4, 0.0]);
}
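The 0.4 entries in the updated gradient assertion are simply (1 / (1 - p)) / N = 2.0 / 5 for the element the mask keeps, since `mean()` contributes 1/N and dropout contributes the 1/(1 - p) scale; the same formula gives the constant V = 1.25 / 24 ≈ 0.052083336 in the 3-D test further down (scale 1.25 over 24 elements). A one-line standalone check of the 1-D value (not part of the diff):

```rust
fn main() {
    let (p, n) = (0.5f32, 5.0f32);
    // Gradient of mean() w.r.t. an element kept by dropout: (1 / (1 - p)) / N.
    let grad = (1.0 / (1.0 - p)) / n;
    assert_eq!(grad, 0.4);
    println!("expected gradient for kept elements: {grad}");
}
```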

#[test]
fn test_dropout_2d() {
let mut rng = StdRng::seed_from_u64(0);
-let t = Tensor2D::new([[0.05, 0.1, 0.2], [0.3, 0.4, 0.5]]);
+let t = Tensor2D::new([[0.05, 0.1, -0.2], [0.3, -0.4, 0.5]]);
let r = t.trace().dropout(0.6, &mut rng);
-assert_eq!(
-r.data(),
-&[[0.12500001, 0.25000003, 0.0], [0.7500001, 1.0000001, 0.0]]
-);
+assert_close(r.data(), &[[0.125, 0.25, -0.5], [0.0, 0.0, 1.25]]);
// NOTE: .exp() so we ensure result grad is used properly
let gradients = r.exp().mean().backward();
assert_eq!(
gradients.ref_gradient(&t),
-&[[0.47214523, 0.5350107, 0.0], [0.88208354, 1.1326177, 0.0]]
+&[[0.47214523, 0.5350107, 0.2527211], [0.0, 0.0, 1.4543099]]
);
}

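In the 2-D test the extra `.exp()` means each kept element contributes exp(r_i) * (1 / (1 - p)) / N to the gradient, where r_i is its forward output; plugging in the first entry (t = 0.05, p = 0.6, N = 6) reproduces the asserted 0.47214523. A standalone check of that arithmetic (not part of the diff):

```rust
fn main() {
    let (p, n) = (0.6f32, 6.0f32);
    let scale = 1.0 / (1.0 - p); // 2.5, the dropout rescale factor
    // First kept element: t = 0.05, forward output r = t * scale = 0.125.
    let r = 0.05f32 * scale;
    // d mean(exp(r)) / dt = exp(r) * scale / N for elements the mask keeps.
    let grad = r.exp() * scale / n;
    assert!((grad - 0.47214523).abs() < 1e-6);
    println!("gradient of the first element: {grad}");
}
```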
@@ -129,21 +143,21 @@
assert_eq!(
r.data(),
&[
-[[1.25, 1.25, 1.25], [1.25, 1.25, 0.0]],
+[[1.25, 1.25, 1.25], [1.25, 1.25, 1.25]],
[[1.25, 1.25, 1.25], [1.25, 0.0, 1.25]],
-[[0.0, 1.25, 1.25], [1.25, 1.25, 1.25]],
-[[1.25, 1.25, 1.25], [1.25, 1.25, 0.0]]
+[[1.25, 1.25, 0.0], [1.25, 0.0, 1.25]],
+[[1.25, 1.25, 1.25], [1.25, 1.25, 1.25]]
]
);
let gradients = r.mean().backward();
const V: f32 = 0.052083336;
assert_eq!(
gradients.ref_gradient(&t),
&[
-[[V, V, V], [V, V, 0.0]],
+[[V, V, V], [V, V, V]],
[[V, V, V], [V, 0.0, V]],
-[[0.0, V, V], [V, V, V]],
-[[V, V, V], [V, V, 0.0]]
+[[V, V, 0.0], [V, 0.0, V]],
+[[V, V, V], [V, V, V]]
]
);
}
