Skip to content

Commit

Permalink
Add tests to BrentRoot
Browse files Browse the repository at this point in the history
Added tests to BrentRoot and a new error message for negative tolerance
  • Loading branch information
GermanHeim authored and stefan-k committed Oct 27, 2024
1 parent 2551ec2 commit c7673ef
Showing 1 changed file with 128 additions and 1 deletion.
129 changes: 128 additions & 1 deletion crates/argmin/src/solver/brent/brentroot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub enum BrentRootError {
/// f(min) and f(max) must have different signs
#[error("BrentRoot error: f(min) and f(max) must have different signs.")]
WrongSign,
// tol must be positive
#[error("BrentRoot error: tol must be positive.")]
NegativeTol,
}

/// # Brent's method
Expand Down Expand Up @@ -95,6 +98,9 @@ where
if self.fa * self.fb > float!(0.0) {
return Err(BrentRootError::WrongSign.into());
}
if self.tol < F::zero() {
return Err(BrentRootError::NegativeTol.into());
}
self.fc = self.fb;
Ok((state.param(self.b).cost(self.fb.abs()), None))
}
Expand Down Expand Up @@ -183,6 +189,127 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::core::Executor;
use approx::assert_relative_eq;

#[derive(Clone)]
struct Quadratic {}

impl CostFunction for Quadratic {
type Param = f64;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
Ok(param.powi(2) - 1.0) // x^2 - 1
}
}

#[test]
fn test_brent_negative_tol() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = -1e-6;

let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization fails and we get the correct error message
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"BrentRoot error: tol must be positive."
);
}

#[test]
fn test_brent_invalid_range() {
let min: f64 = 2.0;
let max: f64 = 3.0;
let tol: f64 = 1e-6;

let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization fails and we get the correct error message
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"BrentRoot error: f(min) and f(max) must have different signs."
);
}

test_trait_impl!(brent, BrentRoot<f64>);
#[test]
fn test_brent_valid_range() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = 1e-6;

let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization is successful
assert!(result.is_ok());
}

#[test]
fn test_brent_find_root() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = 1e-6;
let init_param: f64 = 1.5;

let solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
let problem: Quadratic = Quadratic {};

let res = Executor::new(problem, solver)
.configure(|state| state.param(init_param).max_iters(100))
.run()
.unwrap();

// Check if the result is close to the real root
assert_relative_eq!(res.state.best_param.unwrap(), 1.0, epsilon = tol);
}

#[test]
fn test_brent_symmetry() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = 1e-6;
let init_param: f64 = 1.5;

let problem: Quadratic = Quadratic {};

// First run with [min, max] interval
let solver1: BrentRoot<f64> = BrentRoot::new(min, max, tol);
let res1 = Executor::new(problem.clone(), solver1)
.configure(|state| state.param(init_param).max_iters(100))
.run()
.unwrap();

// Second run with [max, min] interval (swapped inputs)
let solver2: BrentRoot<f64> = BrentRoot::new(max, min, tol);
let res2 = Executor::new(problem, solver2)
.configure(|state| state.param(init_param).max_iters(100))
.run()
.unwrap();

// Check if the results are the same
assert_relative_eq!(
res1.state.param.unwrap(),
res2.state.param.unwrap(),
epsilon = tol,
);

// Check if the number of iterations is the same
assert_eq!(res1.state.get_iter(), res2.state.get_iter());
}
}

0 comments on commit c7673ef

Please sign in to comment.