From 306297e608b025108f07d4a339fd8c890e2f793f Mon Sep 17 00:00:00 2001 From: Michael Baikov Date: Mon, 11 Mar 2024 17:52:42 -0400 Subject: [PATCH 1/2] Disable timing by default, enable when needed --- crates/argmin/src/core/executor.rs | 8 ++++---- crates/argmin/src/core/state/iterstate.rs | 13 +++++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/argmin/src/core/executor.rs b/crates/argmin/src/core/executor.rs index f64bb652b..ccc7c7b0e 100644 --- a/crates/argmin/src/core/executor.rs +++ b/crates/argmin/src/core/executor.rs @@ -69,7 +69,7 @@ where checkpoint: None, timeout: None, ctrlc: true, - timer: true, + timer: false, } } @@ -383,9 +383,9 @@ where self } - /// Enables or disables timing of individual iterations (default: enabled). + /// Enables or disables timing of individual iterations (default: false). /// - /// Setting this to false will silently be ignored in case a timeout is set. + /// In case a timeout is set, this will automatically be set to true. /// /// # Example /// @@ -768,7 +768,7 @@ mod tests { let problem = TestProblem::new(); let timeout = std::time::Duration::from_secs(2); - let executor = Executor::new(problem, solver); + let executor = Executor::new(problem, solver).timer(true); assert!(executor.timer); assert!(executor.timeout.is_none()); diff --git a/crates/argmin/src/core/state/iterstate.rs b/crates/argmin/src/core/state/iterstate.rs index 40a25d924..12074885d 100644 --- a/crates/argmin/src/core/state/iterstate.rs +++ b/crates/argmin/src/core/state/iterstate.rs @@ -81,6 +81,8 @@ pub struct IterState { pub max_iters: u64, /// Evaluation counts pub counts: HashMap, + /// Update evaluation counts? + pub counting_enabled: bool, /// Time required so far pub time: Option, /// Status of optimization execution @@ -1039,6 +1041,7 @@ where last_best_iter: 0, max_iters: std::u64::MAX, counts: HashMap::new(), + counting_enabled: false, time: Some(instant::Duration::new(0, 0)), termination_status: TerminationStatus::NotTerminated, } @@ -1338,7 +1341,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new().counting(true); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -1355,9 +1358,11 @@ where /// # assert_eq!(state.counts, hm); /// ``` fn func_counts(&mut self, problem: &Problem) { - for (k, &v) in problem.counts.iter() { - let count = self.counts.entry(k.to_string()).or_insert(0); - *count = v + if self.counting_enabled { + for (k, &v) in problem.counts.iter() { + let count = self.counts.entry(k.to_string()).or_insert(0); + *count = v + } } } From 79a274a60eef97e3e6da86771efeb6de296a52a6 Mon Sep 17 00:00:00 2001 From: Michael Baikov Date: Wed, 13 Mar 2024 15:36:06 -0400 Subject: [PATCH 2/2] Disable counting by default --- crates/argmin/src/core/state/iterstate.rs | 14 ++++++++++ .../src/core/state/linearprogramstate.rs | 27 ++++++++++++++++--- .../argmin/src/core/state/populationstate.rs | 27 ++++++++++++++++--- crates/argmin/src/solver/brent/brentopt.rs | 2 +- .../src/solver/linesearch/backtracking.rs | 14 ++++++++-- crates/argmin/src/tests.rs | 7 ++++- 6 files changed, 79 insertions(+), 12 deletions(-) diff --git a/crates/argmin/src/core/state/iterstate.rs b/crates/argmin/src/core/state/iterstate.rs index 12074885d..500f5a011 100644 --- a/crates/argmin/src/core/state/iterstate.rs +++ b/crates/argmin/src/core/state/iterstate.rs @@ -971,6 +971,20 @@ where pub fn take_prev_residuals(&mut self) -> Option { self.prev_residuals.take() } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{IterState, State}; + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for IterState diff --git a/crates/argmin/src/core/state/linearprogramstate.rs b/crates/argmin/src/core/state/linearprogramstate.rs index 1642b873f..670f51a93 100644 --- a/crates/argmin/src/core/state/linearprogramstate.rs +++ b/crates/argmin/src/core/state/linearprogramstate.rs @@ -56,6 +56,8 @@ pub struct LinearProgramState { pub max_iters: u64, /// Evaluation counts pub counts: HashMap, + /// Update evaluation counts? + pub counting_enabled: bool, /// Time required so far pub time: Option, /// Status of optimization execution @@ -150,6 +152,20 @@ impl LinearProgramState { self.cost = cost; self } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{State, LinearProgramState}; + /// # let mut state: LinearProgramState, f64> = LinearProgramState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for LinearProgramState @@ -205,6 +221,7 @@ where last_best_iter: 0, max_iters: std::u64::MAX, counts: HashMap::new(), + counting_enabled: false, time: Some(instant::Duration::new(0, 0)), termination_status: TerminationStatus::NotTerminated, } @@ -503,7 +520,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat}; - /// # let mut state: LinearProgramState, f64> = LinearProgramState::new(); + /// # let mut state: LinearProgramState, f64> = LinearProgramState::new().counting(true); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -520,9 +537,11 @@ where /// # assert_eq!(state.counts, hm); /// ``` fn func_counts(&mut self, problem: &Problem) { - for (k, &v) in problem.counts.iter() { - let count = self.counts.entry(k.to_string()).or_insert(0); - *count = v + if self.counting_enabled { + for (k, &v) in problem.counts.iter() { + let count = self.counts.entry(k.to_string()).or_insert(0); + *count = v + } } } diff --git a/crates/argmin/src/core/state/populationstate.rs b/crates/argmin/src/core/state/populationstate.rs index b48efba91..7eda6af4e 100644 --- a/crates/argmin/src/core/state/populationstate.rs +++ b/crates/argmin/src/core/state/populationstate.rs @@ -58,6 +58,8 @@ pub struct PopulationState { pub max_iters: u64, /// Evaluation counts pub counts: HashMap, + /// Update evaluation counts? + pub counting_enabled: bool, /// Time required so far pub time: Option, /// Status of optimization execution @@ -429,6 +431,20 @@ where pub fn take_population(&mut self) -> Option> { self.population.take() } + + /// Overrides state of counting function executions (default: false) + /// ``` + /// # use argmin::core::{State, PopulationState}; + /// # let mut state: PopulationState, f64> = PopulationState::new(); + /// # assert!(!state.counting_enabled); + /// let state = state.counting(true); + /// # assert!(state.counting_enabled); + /// ``` + #[must_use] + pub fn counting(mut self, mode: bool) -> Self { + self.counting_enabled = mode; + self + } } impl State for PopulationState @@ -483,6 +499,7 @@ where last_best_iter: 0, max_iters: std::u64::MAX, counts: HashMap::new(), + counting_enabled: false, time: Some(instant::Duration::new(0, 0)), termination_status: TerminationStatus::NotTerminated, } @@ -782,7 +799,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, PopulationState, State, ArgminFloat}; - /// # let mut state: PopulationState, f64> = PopulationState::new(); + /// # let mut state: PopulationState, f64> = PopulationState::new().counting(true); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -799,9 +816,11 @@ where /// # assert_eq!(state.counts, hm); /// ``` fn func_counts(&mut self, problem: &Problem) { - for (k, &v) in problem.counts.iter() { - let count = self.counts.entry(k.to_string()).or_insert(0); - *count = v + if self.counting_enabled { + for (k, &v) in problem.counts.iter() { + let count = self.counts.entry(k.to_string()).or_insert(0); + *count = v + } } } diff --git a/crates/argmin/src/solver/brent/brentopt.rs b/crates/argmin/src/solver/brent/brentopt.rs index ade3e820e..96f8d56ec 100644 --- a/crates/argmin/src/solver/brent/brentopt.rs +++ b/crates/argmin/src/solver/brent/brentopt.rs @@ -231,7 +231,7 @@ mod tests { let cost = TestFunc {}; let solver = BrentOpt::new(-10., 10.); let res = Executor::new(cost, solver) - .configure(|state| state.max_iters(13)) + .configure(|state| state.counting(true).max_iters(13)) .run() .unwrap(); assert_eq!( diff --git a/crates/argmin/src/solver/linesearch/backtracking.rs b/crates/argmin/src/solver/linesearch/backtracking.rs index e5acbd92c..0b2e42c68 100644 --- a/crates/argmin/src/solver/linesearch/backtracking.rs +++ b/crates/argmin/src/solver/linesearch/backtracking.rs @@ -640,7 +640,12 @@ mod tests { ls.search_direction(vec![2.0f64, 0.0]); let data = Executor::new(prob, ls.clone()) - .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10)) + .configure(|config| { + config + .counting(true) + .param(ls.init_param.clone().unwrap()) + .max_iters(10) + }) .run(); assert!(data.is_ok()); @@ -689,7 +694,12 @@ mod tests { ls.search_direction(vec![2.0f64, 0.0]); let data = Executor::new(prob, ls.clone()) - .configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10)) + .configure(|config| { + config + .counting(true) + .param(ls.init_param.clone().unwrap()) + .max_iters(10) + }) .run(); assert!(data.is_ok()); diff --git a/crates/argmin/src/tests.rs b/crates/argmin/src/tests.rs index d7bc9fa74..c7370ec52 100644 --- a/crates/argmin/src/tests.rs +++ b/crates/argmin/src/tests.rs @@ -161,7 +161,12 @@ fn test_lbfgs_func_count() { let linesearch = MoreThuenteLineSearch::new(); let solver = LBFGS::new(linesearch, 10); let res = Executor::new(cost.clone(), solver) - .configure(|config| config.param(cost.param_init.clone()).max_iters(100)) + .configure(|config| { + config + .param(cost.param_init.clone()) + .max_iters(100) + .counting(true) + }) .run() .unwrap();