diff --git a/crates/argmin/src/core/executor.rs b/crates/argmin/src/core/executor.rs
index f64bb652b..ccc7c7b0e 100644
--- a/crates/argmin/src/core/executor.rs
+++ b/crates/argmin/src/core/executor.rs
@@ -69,7 +69,7 @@ where
checkpoint: None,
timeout: None,
ctrlc: true,
- timer: true,
+ timer: false,
}
}
@@ -383,9 +383,9 @@ where
self
}
- /// Enables or disables timing of individual iterations (default: enabled).
+ /// Enables or disables timing of individual iterations (default: false).
///
- /// Setting this to false will silently be ignored in case a timeout is set.
+ /// In case a timeout is set, this will automatically be set to true.
///
/// # Example
///
@@ -768,7 +768,7 @@ mod tests {
let problem = TestProblem::new();
let timeout = std::time::Duration::from_secs(2);
- let executor = Executor::new(problem, solver);
+ let executor = Executor::new(problem, solver).timer(true);
assert!(executor.timer);
assert!(executor.timeout.is_none());
diff --git a/crates/argmin/src/core/state/iterstate.rs b/crates/argmin/src/core/state/iterstate.rs
index 40a25d924..500f5a011 100644
--- a/crates/argmin/src/core/state/iterstate.rs
+++ b/crates/argmin/src/core/state/iterstate.rs
@@ -81,6 +81,8 @@ pub struct IterState
{
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap,
+ /// Update evaluation counts?
+ pub counting_enabled: bool,
/// Time required so far
pub time: Option,
/// Status of optimization execution
@@ -969,6 +971,20 @@ where
pub fn take_prev_residuals(&mut self) -> Option {
self.prev_residuals.take()
}
+
+ /// Overrides state of counting function executions (default: false)
+ /// ```
+ /// # use argmin::core::{IterState, State};
+ /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new();
+ /// # assert!(!state.counting_enabled);
+ /// let state = state.counting(true);
+ /// # assert!(state.counting_enabled);
+ /// ```
+ #[must_use]
+ pub fn counting(mut self, mode: bool) -> Self {
+ self.counting_enabled = mode;
+ self
+ }
}
impl State for IterState
@@ -1039,6 +1055,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
+ counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
@@ -1338,7 +1355,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, IterState, State, ArgminFloat};
- /// # let mut state: IterState, (), (), (), (), f64> = IterState::new();
+ /// # let mut state: IterState, (), (), (), (), f64> = IterState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
@@ -1355,9 +1372,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts(&mut self, problem: &Problem) {
- for (k, &v) in problem.counts.iter() {
- let count = self.counts.entry(k.to_string()).or_insert(0);
- *count = v
+ if self.counting_enabled {
+ for (k, &v) in problem.counts.iter() {
+ let count = self.counts.entry(k.to_string()).or_insert(0);
+ *count = v
+ }
}
}
diff --git a/crates/argmin/src/core/state/linearprogramstate.rs b/crates/argmin/src/core/state/linearprogramstate.rs
index 1642b873f..670f51a93 100644
--- a/crates/argmin/src/core/state/linearprogramstate.rs
+++ b/crates/argmin/src/core/state/linearprogramstate.rs
@@ -56,6 +56,8 @@ pub struct LinearProgramState {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap,
+ /// Update evaluation counts?
+ pub counting_enabled: bool,
/// Time required so far
pub time: Option,
/// Status of optimization execution
@@ -150,6 +152,20 @@ impl LinearProgramState
{
self.cost = cost;
self
}
+
+ /// Overrides state of counting function executions (default: false)
+ /// ```
+ /// # use argmin::core::{State, LinearProgramState};
+ /// # let mut state: LinearProgramState, f64> = LinearProgramState::new();
+ /// # assert!(!state.counting_enabled);
+ /// let state = state.counting(true);
+ /// # assert!(state.counting_enabled);
+ /// ```
+ #[must_use]
+ pub fn counting(mut self, mode: bool) -> Self {
+ self.counting_enabled = mode;
+ self
+ }
}
impl State for LinearProgramState
@@ -205,6 +221,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
+ counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
@@ -503,7 +520,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
- /// # let mut state: LinearProgramState, f64> = LinearProgramState::new();
+ /// # let mut state: LinearProgramState, f64> = LinearProgramState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
@@ -520,9 +537,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts(&mut self, problem: &Problem) {
- for (k, &v) in problem.counts.iter() {
- let count = self.counts.entry(k.to_string()).or_insert(0);
- *count = v
+ if self.counting_enabled {
+ for (k, &v) in problem.counts.iter() {
+ let count = self.counts.entry(k.to_string()).or_insert(0);
+ *count = v
+ }
}
}
diff --git a/crates/argmin/src/core/state/populationstate.rs b/crates/argmin/src/core/state/populationstate.rs
index b48efba91..7eda6af4e 100644
--- a/crates/argmin/src/core/state/populationstate.rs
+++ b/crates/argmin/src/core/state/populationstate.rs
@@ -58,6 +58,8 @@ pub struct PopulationState {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap,
+ /// Update evaluation counts?
+ pub counting_enabled: bool,
/// Time required so far
pub time: Option,
/// Status of optimization execution
@@ -429,6 +431,20 @@ where
pub fn take_population(&mut self) -> Option> {
self.population.take()
}
+
+ /// Overrides state of counting function executions (default: false)
+ /// ```
+ /// # use argmin::core::{State, PopulationState};
+ /// # let mut state: PopulationState, f64> = PopulationState::new();
+ /// # assert!(!state.counting_enabled);
+ /// let state = state.counting(true);
+ /// # assert!(state.counting_enabled);
+ /// ```
+ #[must_use]
+ pub fn counting(mut self, mode: bool) -> Self {
+ self.counting_enabled = mode;
+ self
+ }
}
impl State for PopulationState
@@ -483,6 +499,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
+ counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
@@ -782,7 +799,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
- /// # let mut state: PopulationState, f64> = PopulationState::new();
+ /// # let mut state: PopulationState, f64> = PopulationState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
@@ -799,9 +816,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts(&mut self, problem: &Problem) {
- for (k, &v) in problem.counts.iter() {
- let count = self.counts.entry(k.to_string()).or_insert(0);
- *count = v
+ if self.counting_enabled {
+ for (k, &v) in problem.counts.iter() {
+ let count = self.counts.entry(k.to_string()).or_insert(0);
+ *count = v
+ }
}
}
diff --git a/crates/argmin/src/solver/brent/brentopt.rs b/crates/argmin/src/solver/brent/brentopt.rs
index ade3e820e..96f8d56ec 100644
--- a/crates/argmin/src/solver/brent/brentopt.rs
+++ b/crates/argmin/src/solver/brent/brentopt.rs
@@ -231,7 +231,7 @@ mod tests {
let cost = TestFunc {};
let solver = BrentOpt::new(-10., 10.);
let res = Executor::new(cost, solver)
- .configure(|state| state.max_iters(13))
+ .configure(|state| state.counting(true).max_iters(13))
.run()
.unwrap();
assert_eq!(
diff --git a/crates/argmin/src/solver/linesearch/backtracking.rs b/crates/argmin/src/solver/linesearch/backtracking.rs
index e5acbd92c..0b2e42c68 100644
--- a/crates/argmin/src/solver/linesearch/backtracking.rs
+++ b/crates/argmin/src/solver/linesearch/backtracking.rs
@@ -640,7 +640,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);
let data = Executor::new(prob, ls.clone())
- .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
+ .configure(|config| {
+ config
+ .counting(true)
+ .param(ls.init_param.clone().unwrap())
+ .max_iters(10)
+ })
.run();
assert!(data.is_ok());
@@ -689,7 +694,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);
let data = Executor::new(prob, ls.clone())
- .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
+ .configure(|config| {
+ config
+ .counting(true)
+ .param(ls.init_param.clone().unwrap())
+ .max_iters(10)
+ })
.run();
assert!(data.is_ok());
diff --git a/crates/argmin/src/tests.rs b/crates/argmin/src/tests.rs
index d7bc9fa74..c7370ec52 100644
--- a/crates/argmin/src/tests.rs
+++ b/crates/argmin/src/tests.rs
@@ -161,7 +161,12 @@ fn test_lbfgs_func_count() {
let linesearch = MoreThuenteLineSearch::new();
let solver = LBFGS::new(linesearch, 10);
let res = Executor::new(cost.clone(), solver)
- .configure(|config| config.param(cost.param_init.clone()).max_iters(100))
+ .configure(|config| {
+ config
+ .param(cost.param_init.clone())
+ .max_iters(100)
+ .counting(true)
+ })
.run()
.unwrap();