Skip to content

Commit

Permalink
Adds better assertion macros for testing (#714)
Browse files Browse the repository at this point in the history
* Adds assert_aclose! macro for testing

* Renaming & removing more assert closes

* Improving error messages

* Fix failing doctests

* Removing assert_close()

* Fixes f64 tests
  • Loading branch information
coreylowman authored Apr 19, 2023
1 parent 7cde7d8 commit ef287c1
Show file tree
Hide file tree
Showing 64 changed files with 956 additions and 996 deletions.
104 changes: 96 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ pub(crate) mod tests {
pub trait AssertClose {
type Elem: std::fmt::Display + std::fmt::Debug + Copy;
const DEFAULT_TOLERANCE: Self::Elem;
fn get_default_tol(&self) -> Self::Elem {
Self::DEFAULT_TOLERANCE
}
fn get_far_pair(
&self,
rhs: &Self,
Expand Down Expand Up @@ -313,15 +316,100 @@ pub(crate) mod tests {
}
}

pub fn assert_close<T: AssertClose + std::fmt::Debug>(a: &T, b: &T) {
a.assert_close(b, T::DEFAULT_TOLERANCE);
pub trait NdMap {
type Elem;
type Mapped<O>;
fn ndmap<O, F: Copy + FnMut(Self::Elem) -> O>(self, f: F) -> Self::Mapped<O>;
}

impl NdMap for f32 {
type Elem = Self;
type Mapped<O> = O;
fn ndmap<O, F: Copy + FnMut(Self::Elem) -> O>(self, mut f: F) -> O {
f(self)
}
}

impl NdMap for f64 {
type Elem = Self;
type Mapped<O> = O;
fn ndmap<O, F: Copy + FnMut(Self::Elem) -> O>(self, mut f: F) -> O {
f(self)
}
}

impl<T: NdMap, const M: usize> NdMap for [T; M] {
type Elem = T::Elem;
type Mapped<O> = [T::Mapped<O>; M];
fn ndmap<O, F: Copy + FnMut(Self::Elem) -> O>(self, f: F) -> Self::Mapped<O> {
self.map(|t| t.ndmap(f))
}
}

pub fn assert_close_with_tolerance<T: AssertClose + std::fmt::Debug>(
a: &T,
b: &T,
tolerance: T::Elem,
) {
a.assert_close(b, tolerance);
macro_rules! assert_close_to_literal {
($Lhs:expr, $Rhs:expr) => {{
let lhs = $Lhs.array();
let tol = AssertClose::get_default_tol(&lhs);
let far_pair = AssertClose::get_far_pair(
&lhs,
&$Rhs.ndmap(|x| num_traits::FromPrimitive::from_f64(x).unwrap()),
tol,
);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
}};
($Lhs:expr, $Rhs:expr, $Tolerance:expr) => {{
let far_pair = $Lhs.array().get_far_pair(
&$Rhs.ndmap(|x| num_traits::FromPrimitive::from_f64(x).unwrap()),
num_traits::FromPrimitive::from_f64($Tolerance).unwrap(),
);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
}};
}
pub(crate) use assert_close_to_literal;

macro_rules! assert_close_to_tensor {
($Lhs:expr, $Rhs:expr) => {
let lhs = $Lhs.array();
let tol = AssertClose::get_default_tol(&lhs);
let far_pair = AssertClose::get_far_pair(&lhs, &$Rhs.array(), tol);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
};
($Lhs:expr, $Rhs:expr, $Tolerance:expr) => {{
let far_pair = $Lhs.array().get_far_pair(
&$Rhs.array(),
num_traits::FromPrimitive::from_f64($Tolerance).unwrap(),
);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
}};
}
pub(crate) use assert_close_to_tensor;

macro_rules! assert_close {
($Lhs:expr, $Rhs:expr) => {
let lhs = $Lhs;
let tol = AssertClose::get_default_tol(&lhs);
let far_pair = AssertClose::get_far_pair(&lhs, &$Rhs, tol);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
};
($Lhs:expr, $Rhs:expr, $Tolerance:expr) => {{
let far_pair = $Lhs.get_far_pair(
&$Rhs,
num_traits::FromPrimitive::from_f64($Tolerance).unwrap(),
);
if let Some((l, r)) = far_pair {
panic!("lhs != rhs | {l} != {r}");
}
}};
}
pub(crate) use assert_close;
}
98 changes: 49 additions & 49 deletions src/losses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ mod tests {
let y: Tensor<_, TestDtype, _> =
dev.tensor([-0.90954804, -1.0193185, -0.39221755, 2.2524886, 1.3035554]);
let loss = mse_loss(x.leaky_trace(), y);
assert_close(&loss.array(), &1.0846305);
assert_close_to_literal!(loss, 1.0846305);
let g = loss.backward();
assert_close(
&g.get(&x).array(),
&[0.7128116, 0.31071725, -0.24555098, -0.43896183, 0.10037976],
assert_close_to_literal!(
g.get(&x),
[0.7128116, 0.31071725, -0.24555098, -0.43896183, 0.10037976]
);
}

Expand All @@ -152,9 +152,9 @@ mod tests {
let y: Tensor<_, TestDtype, _> =
dev.tensor([-0.90954804, -1.0193186, -0.39221755, 2.2524886, 1.3035554]);
let loss = mae_loss(x.leaky_trace(), y);
assert_close(&loss.array(), &0.9042107);
assert_close_to_literal!(loss, 0.9042107);
let g = loss.backward();
assert_eq!(g.get(&x).array(), [0.2, 0.2, -0.2, -0.2, 0.2]);
assert_close_to_literal!(g.get(&x), [0.2, 0.2, -0.2, -0.2, 0.2]);
}

#[test]
Expand All @@ -169,21 +169,21 @@ mod tests {
[0.15627657, 0.29779273, 0.10897867, 0.2879545, 0.14899758],
]);
let loss = cross_entropy_with_logits_loss(x.leaky_trace(), y.clone());
assert_close(&loss.array(), &1.9889611);
assert_close_to_literal!(loss, 1.9889611);
let g = loss.backward();
assert_close(
&g.get(&x).array(),
&[
assert_close_to_literal!(
g.get(&x),
[
[-0.0972354, 0.0515665, -0.09250933, 0.07864318, 0.05953507],
[0.0035581, 0.1792296, -0.0074167, -0.1233234, -0.0520476],
],
]
);
assert_close(
&g.get(&y).array(),
&[
assert_close_to_literal!(
g.get(&y),
[
[1.0454637, 0.6836907, 1.4958019, 0.70222294, 0.56051415],
[0.9057989, 0.21060522, 1.1814584, 1.5933538, 1.5516331],
],
]
);
}

Expand All @@ -198,7 +198,7 @@ mod tests {
targ[i] = 1.0;
let y = dev.tensor(targ);
let loss = cross_entropy_with_logits_loss(x.leaky_trace(), y.clone());
assert_close(&loss.array(), &losses[i]);
assert_close_to_literal!(loss, losses[i]);
}
}

Expand All @@ -220,17 +220,17 @@ mod tests {
[0.0166, 0.8512, 0.1322],
]);
let loss = kl_div_with_logits_loss(logits.leaky_trace(), targ);
assert_close(&loss.array(), &0.40656143);
assert_close_to_literal!(loss, 0.40656143);
let g = loss.backward();
assert_close(
&g.get(&logits).array(),
&[
assert_close_to_literal!(
g.get(&logits),
[
[-0.031813223, -0.044453412, 0.07626665],
[0.05489187, -0.04143352, -0.013458336],
[-0.037454266, 0.02207594, 0.015378334],
[-0.09656205, 0.013436668, 0.083125375],
[0.02881821, -0.10633193, 0.0775137],
],
]
);
}

Expand All @@ -248,26 +248,26 @@ mod tests {
[0.7026833, 0.5563793, 0.6429267],
]);
let loss = binary_cross_entropy_with_logits_loss(logit.leaky_trace(), prob.clone());
assert_close(&loss.array(), &0.7045728);
assert_close_to_literal!(loss, 0.7045728);

let g = loss.backward();

assert_close(
&g.get(&logit).array(),
assert_close_to_literal!(
g.get(&logit),
&[
[0.003761424, -0.054871976, 0.025817735],
[-0.0009343492, 0.0051718787, 0.0074731046],
[-0.047248676, -0.03401173, 0.0071035423],
],
]
);

assert_close(
&g.get(&prob).array(),
assert_close_to_literal!(
g.get(&prob),
&[
[0.04546672, 0.07451131, -0.10224107],
[0.18426175, -0.18865204, 0.16475087],
[0.10635218, 0.12190584, -0.097797275],
],
]
);
}

Expand All @@ -279,26 +279,26 @@ mod tests {
let targ: Tensor<_, TestDtype, _> = dev.tensor([[0.0, 0.5, 1.0]; 3]);

let loss = binary_cross_entropy_with_logits_loss(logit.leaky_trace(), targ.clone());
assert_close(&loss.array(), &33.479964);
assert_close_to_literal!(loss, 33.479964);

let g = loss.backward();

assert_close(
&g.get(&logit).array(),
assert_close_to_literal!(
g.get(&logit),
&[
[0.11111111, 0.055555556, 0.0],
[0.0, -0.055555556, -0.11111111],
[0.029882379, 0.0, -0.02988238],
],
]
);

assert_close(
&g.get(&targ).array(),
assert_close_to_literal!(
g.get(&targ),
&[
[-11.111112, -11.111112, -11.111112],
[11.111112, 11.111112, 11.111112],
[0.11111111, 0.0, -0.11111111],
],
]
);
}

Expand All @@ -317,24 +317,24 @@ mod tests {
]);

let loss = huber_loss(x.leaky_trace(), y.clone(), 0.5);
assert_close(&loss.array(), &0.24506615);
assert_close_to_literal!(loss, 0.24506615);

let g = loss.backward();
assert_close(
&g.get(&x).array(),
assert_close_to_literal!(
g.get(&x),
&[
[-0.016490579, 0.014802615, -0.033333335, -0.012523981, 0.0],
[0.033333335, -0.0099870805, -0.033333335, 0.033333335, 0.0],
[0.033333335, -0.033333335, -0.02631244, 0.033333335, 0.0],
],
]
);
assert_close(
&g.get(&y).array(),
assert_close_to_literal!(
g.get(&y),
&[
[0.016490579, -0.014802615, 0.033333335, 0.012523981, 0.0],
[-0.033333335, 0.0099870805, 0.033333335, -0.033333335, 0.0],
[-0.033333335, 0.033333335, 0.02631244, -0.033333335, 0.0],
],
]
);
}

Expand All @@ -353,24 +353,24 @@ mod tests {
]);

let loss = smooth_l1_loss(x.leaky_trace(), y.clone(), 0.5);
assert_close(&loss.array(), &0.4901323);
assert_close_to_literal!(loss, 0.4901323);

let g = loss.backward();
assert_close(
&g.get(&x).array(),
assert_close_to_literal!(
g.get(&x),
&[
[-0.032981157, 0.02960523, -0.06666667, -0.025047962, 0.0],
[0.06666667, -0.019974161, -0.06666667, 0.06666667, 0.0],
[0.06666667, -0.06666667, -0.05262488, 0.06666667, 0.0],
],
]
);
assert_close(
&g.get(&y).array(),
assert_close_to_literal!(
g.get(&y),
&[
[0.032981157, -0.02960523, 0.06666667, 0.025047962, 0.0],
[-0.06666667, 0.019974161, 0.06666667, -0.06666667, 0.0],
[-0.06666667, 0.06666667, 0.05262488, -0.06666667, 0.0],
],
]
);
}
}
Loading

0 comments on commit ef287c1

Please sign in to comment.