From 053984e2a2358769fe9665c8d20475de5d223ddf Mon Sep 17 00:00:00 2001
From: Corey Lowman <clowman1993@gmail.com>
Date: Wed, 20 Jul 2022 16:53:30 -0400
Subject: [PATCH] Changing dropout to use map_df_uses_fx

---
 src/tensor_ops/impl_dropout.rs | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/tensor_ops/impl_dropout.rs b/src/tensor_ops/impl_dropout.rs
index 85d816df9..94fb3e353 100644
--- a/src/tensor_ops/impl_dropout.rs
+++ b/src/tensor_ops/impl_dropout.rs
@@ -1,7 +1,6 @@
-use super::utils::move_tape_and_add_backward_op;
 use crate::prelude::*;
 use rand::Rng;
-use rand_distr::{Distribution, Standard};
+use rand_distr::Standard;
 
 /// Randomly drops out elements from `t` with probability `p`, and multiplies all elements by `1 / (1 - p)`.
 ///
@@ -15,17 +14,17 @@ pub fn dropout<T: Tensor<Dtype = f32>, R: Rng>(t: T, p: f32, rng: &mut R) -> T {
     } else {
         // `t` owns the tape in this branch, so apply dropout randomly.
         let rinvp = (1.0 - p).recip();
-        let deriv = T::Device::filled(&mut |d| {
-            let val: f32 = Standard.sample(rng);
-            *d = if val < p { 0.0 } else { rinvp };
-        });
-        let mut result = T::NoTape::zeros();
-        T::Device::addmul(result.mut_data(), t.data(), deriv.as_ref());
-
-        move_tape_and_add_backward_op(t, result, move |t, result, grads| {
-            let (t_grad, result_grad) = grads.mut_and_ref(&t, &result);
-            T::Device::addmul(t_grad, deriv.as_ref(), result_grad);
-        })
+        map_df_uses_fx(
+            t,
+            move |x| {
+                if rng.sample::<f32, Standard>(Standard) < p {
+                    0.0
+                } else {
+                    x * rinvp
+                }
+            },
+            move |fx| if fx > &0.0 { rinvp } else { 0.0 },
+        )
     }
 }