Skip to content

Commit

Permalink
#80 Adding nn::Softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 14, 2022
1 parent 770e2ea commit ea913f2
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/nn/activations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ activation_impls!(Tanh, tanh, #[doc="Unit struct that impls [Module] as calling
// Each `activation_impls!` invocation generates a zero-sized unit struct whose
// [Module] impl forwards `input` to the named free function (e.g. `square()`).
// NOTE(review): the macro body is defined earlier in this file (not visible in
// this hunk) — the doc string passed here becomes the struct's rustdoc.
activation_impls!(Square, square, #[doc="Unit struct that impls [Module] as calling [square()] on `input`."]);
activation_impls!(Sqrt, sqrt, #[doc="Unit struct that impls [Module] as calling [sqrt()] on `input`."]);
activation_impls!(Abs, abs, #[doc="Unit struct that impls [Module] as calling [abs()] on `input`."]);
// Newly added in this commit: Softmax as an activation module (#80).
activation_impls!(Softmax, softmax, #[doc="Unit struct that impls [Module] as calling [softmax()] on `input`."]);

#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -119,4 +120,22 @@ mod tests {
let r2 = abs(t);
assert_eq!(r1.data(), r2.data());
}

#[test]
fn test_softmax() {
    // The `Softmax` module's `forward` must agree with the free function
    // `softmax()` across tensor ranks 0, 1, and 2. Each tensor is cloned
    // because both call paths take ownership of their input.

    // Rank 0: a single scalar.
    let scalar = Tensor0D::new(0.0);
    let module_out = Softmax.forward(scalar.clone());
    let fn_out = softmax(scalar);
    assert_eq!(module_out.data(), fn_out.data());

    // Rank 1: a vector spanning negative and positive values.
    let vector = Tensor1D::new([-2.0, -1.0, 0.0, 1.0, 2.0]);
    let module_out = Softmax.forward(vector.clone());
    let fn_out = softmax(vector);
    assert_eq!(module_out.data(), fn_out.data());

    // Rank 2: a small matrix, checking the row-wise behavior matches too.
    let matrix = Tensor2D::new([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]);
    let module_out = Softmax.forward(matrix.clone());
    let fn_out = softmax(matrix);
    assert_eq!(module_out.data(), fn_out.data());
}
}

0 comments on commit ea913f2

Please sign in to comment.