Skip to content

Commit

Permalink
Changing dropout to use map_df_uses_fx
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 20, 2022
1 parent da54efc commit 053984e
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/tensor_ops/impl_dropout.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::utils::move_tape_and_add_backward_op;
use crate::prelude::*;
use rand::Rng;
use rand_distr::{Distribution, Standard};
use rand_distr::Standard;

/// Randomly drops out elements from `t` with probability `p`, and multiplies all elements by `1 / (1 - p)`.
///
Expand All @@ -15,17 +14,17 @@ pub fn dropout<T: Tensor<Dtype = f32>, R: Rng>(t: T, p: f32, rng: &mut R) -> T {
} else {
// `t` owns the tape in this branch, so apply dropout randomly.
let rinvp = (1.0 - p).recip();
let deriv = T::Device::filled(&mut |d| {
let val: f32 = Standard.sample(rng);
*d = if val < p { 0.0 } else { rinvp };
});
let mut result = T::NoTape::zeros();
T::Device::addmul(result.mut_data(), t.data(), deriv.as_ref());

move_tape_and_add_backward_op(t, result, move |t, result, grads| {
let (t_grad, result_grad) = grads.mut_and_ref(&t, &result);
T::Device::addmul(t_grad, deriv.as_ref(), result_grad);
})
map_df_uses_fx(
t,
move |x| {
if rng.sample::<f32, Standard>(Standard) < p {
0.0
} else {
x * rinvp
}
},
move |fx| if fx > &0.0 { rinvp } else { 0.0 },
)
}
}

Expand Down

0 comments on commit 053984e

Please sign in to comment.