Skip to content

Commit

Permalink
disable counting by default
Browse files Browse the repository at this point in the history
  • Loading branch information
pacak committed Mar 13, 2024
1 parent ac13db9 commit 16df1c5
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 11 deletions.
14 changes: 14 additions & 0 deletions crates/argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,20 @@ where
pub fn take_prev_residuals(&mut self) -> Option<R> {
self.prev_residuals.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{IterState, State};
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>
Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/linearprogramstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub struct LinearProgramState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -151,6 +153,20 @@ impl<P, F> LinearProgramState<P, F> {
self.cost = cost;
self
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, LinearProgramState};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for LinearProgramState<P, F>
Expand Down Expand Up @@ -206,6 +222,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -504,7 +521,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -521,9 +538,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/populationstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pub struct PopulationState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -430,6 +432,20 @@ where
pub fn take_population(&mut self) -> Option<Vec<P>> {
self.population.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, PopulationState};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for PopulationState<P, F>
Expand Down Expand Up @@ -484,6 +500,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -783,7 +800,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -800,9 +817,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ mod tests {
let cost = TestFunc {};
let solver = BrentOpt::new(-10., 10.);
let res = Executor::new(cost, solver)
.configure(|state| state.max_iters(13))
.configure(|state| state.counting(true).max_iters(13))
.run()
.unwrap();
assert_eq!(
Expand Down
14 changes: 12 additions & 2 deletions crates/argmin/src/solver/linesearch/backtracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down Expand Up @@ -689,7 +694,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down

0 comments on commit 16df1c5

Please sign in to comment.