Skip to content

Commit

Permalink
#69 adding map_df_uses_fx
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 20, 2022
1 parent 977b452 commit 6a61e75
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions src/tensor_ops/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::ops::Neg;
/// assert_eq!(r.data(), &[2.0, 0.0, -5.0]);
/// ```
pub fn negate<T: Tensor<Dtype = f32>>(t: T) -> T {
// Forward: -x. The derivative of -x is the constant -1, so `df` ignores its argument.
map(t, |x| -x, |_| -1.0)
// Same forward/derivative pair routed through `map_df_uses_fx`; `df` still ignores its argument.
map_df_uses_fx(t, |x| -x, |_| -1.0)
}

/// `max(0, t)`. Computes [Rectified Linear Unit (ReLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)).
Expand All @@ -32,7 +32,7 @@ pub fn negate<T: Tensor<Dtype = f32>>(t: T) -> T {
/// let r2 = t.relu();
/// ```
pub fn relu<T: Tensor<Dtype = f32>>(t: T) -> T {
// `map` form: the derivative closure tests the input x: 1 where x > 0, else 0.
map(t, |x| x.max(0.0), |x| if x > &0.0 { 1.0 } else { 0.0 })
// `map_df_uses_fx` form: the closure tests fx = max(x, 0) instead; fx > 0 exactly when x > 0,
// so the same threshold test applies to the forward output.
map_df_uses_fx(t, |x| x.max(0.0), |fx| if fx > &0.0 { 1.0 } else { 0.0 })
}

/// `t^2`
Expand Down Expand Up @@ -70,7 +70,7 @@ pub fn square<T: Tensor<Dtype = f32>>(t: T) -> T {
/// let r2 = t.sqrt();
/// ```
pub fn sqrt<T: Tensor<Dtype = f32>>(t: T) -> T {
// `map` form: derivative written in terms of x, d/dx sqrt(x) = 0.5 / sqrt(x).
map(t, |x| x.sqrt(), |x| 0.5 * x.sqrt().recip())
// `map_df_uses_fx` form: with fx = sqrt(x) the same derivative is 0.5 / fx,
// so the square root is not recomputed inside `df`.
map_df_uses_fx(t, |x| x.sqrt(), |fx| 0.5 * fx.recip())
}

/// `tanh(t)`. Computes the [Hyperbolic Tangent (Tanh)](https://en.wikipedia.org/wiki/Hyperbolic_functions).
Expand All @@ -89,7 +89,7 @@ pub fn sqrt<T: Tensor<Dtype = f32>>(t: T) -> T {
/// let r2 = t.tanh();
/// ```
pub fn tanh<T: Tensor<Dtype = f32>>(t: T) -> T {
// `map` form: derivative written in terms of x, d/dx tanh(x) = 1 - tanh(x)^2.
map(t, |x| x.tanh(), |x| 1.0 - x.tanh().powi(2))
// `map_df_uses_fx` form: with fx = tanh(x) the same derivative is 1 - fx^2,
// so tanh is not recomputed inside `df`.
map_df_uses_fx(t, |x| x.tanh(), |fx| 1.0 - fx.powi(2))
}

/// `1 / (1 + exp(-t))`. Computes [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function).
Expand All @@ -112,12 +112,7 @@ pub fn sigmoid<T: Tensor<Dtype = f32>>(t: T) -> T {
(1.0 + x.neg().exp()).recip()
}

fn df(x: &f32) -> f32 {
let s = f(x);
s * (1.0 - s)
}

map(t, f, df)
map_df_uses_fx(t, f, |fx| fx * (1.0 - fx))
}

/// `sin(t)`. Computes the [sine function](https://en.wikipedia.org/wiki/Sine_and_cosine).
Expand Down Expand Up @@ -193,7 +188,7 @@ pub fn ln<T: Tensor<Dtype = f32>>(t: T) -> T {
/// let r2 = t.exp();
/// ```
pub fn exp<T: Tensor<Dtype = f32>>(t: T) -> T {
// `map` form: derivative written in terms of x, d/dx exp(x) = exp(x).
map(t, |x| x.exp(), |x| x.exp())
// `map_df_uses_fx` form: with fx = exp(x) the derivative is fx itself, so `df` just copies it.
map_df_uses_fx(t, |x| x.exp(), |fx| *fx)
}

/// `|t|`. Computes the [absolute value (abs)](https://en.wikipedia.org/wiki/Absolute_value).
Expand Down Expand Up @@ -242,6 +237,25 @@ where
})
}

/// Variant of [map()] in which `df` is evaluated on `f(x)` (the forward output) rather than
/// on `x`. This can potentially remove an allocation.
///
/// `f` is written over the elements of `t` through `t.mut_data()`, and the backward op
/// registered on the tape accumulates `df(f(x)) * result_grad` into the gradient of `t`.
pub fn map_df_uses_fx<T: Tensor<Dtype = f32>, F, Df>(mut t: T, mut f: F, mut df: Df) -> T
where
    F: FnMut(&f32) -> f32,
    Df: 'static + FnMut(&f32) -> f32,
{
    // Replace each element with f(x). Per the original author's note, `mut_data` clones the
    // data if there is more than one reference to t.
    T::Device::foreach_m(t.mut_data(), &mut |elem| {
        let fx = f(elem);
        *elem = fx;
    });

    let (fx_tensor, mut tape) = t.split_tape();

    // NOTE(review): presumably this clone shares t's data through a new reference instead of
    // copying it -- confirm against the tensor `Clone` impl.
    let output = fx_tensor.clone();
    let output_marker = output.phantom();

    tape.add_backward_op(move |grads| {
        let (input_grad, output_grad) = grads.mut_and_ref(&fx_tensor, &output_marker);
        // `fx_tensor` already holds f(x), so `df` is handed the forward output directly.
        T::Device::foreach_mrr(
            input_grad,
            fx_tensor.data(),
            output_grad,
            &mut |grad, fx, upstream| {
                let slope = df(fx);
                *grad += slope * upstream;
            },
        );
    });

    output.put_tape(tape)
}

macro_rules! activation_impl {
($func_name:ident, #[$docstring:meta]) => {
#[$docstring]
Expand Down

0 comments on commit 6a61e75

Please sign in to comment.