Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some minor optimizations #483

Merged
merged 2 commits into from
Mar 30, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Disable counting by default
  • Loading branch information
pacak committed Mar 29, 2024
commit 79a274a60eef97e3e6da86771efeb6de296a52a6
14 changes: 14 additions & 0 deletions crates/argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
@@ -971,6 +971,20 @@ where
pub fn take_prev_residuals(&mut self) -> Option<R> {
self.prev_residuals.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{IterState, State};
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/linearprogramstate.rs
Original file line number Diff line number Diff line change
@@ -56,6 +56,8 @@ pub struct LinearProgramState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
@@ -150,6 +152,20 @@ impl<P, F> LinearProgramState<P, F> {
self.cost = cost;
self
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, LinearProgramState};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for LinearProgramState<P, F>
@@ -205,6 +221,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
@@ -503,7 +520,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
@@ -520,9 +537,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/populationstate.rs
Original file line number Diff line number Diff line change
@@ -58,6 +58,8 @@ pub struct PopulationState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
@@ -429,6 +431,20 @@ where
pub fn take_population(&mut self) -> Option<Vec<P>> {
self.population.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, PopulationState};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for PopulationState<P, F>
@@ -483,6 +499,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
@@ -782,7 +799,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
@@ -799,9 +816,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

2 changes: 1 addition & 1 deletion crates/argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
@@ -231,7 +231,7 @@ mod tests {
let cost = TestFunc {};
let solver = BrentOpt::new(-10., 10.);
let res = Executor::new(cost, solver)
.configure(|state| state.max_iters(13))
.configure(|state| state.counting(true).max_iters(13))
.run()
.unwrap();
assert_eq!(
14 changes: 12 additions & 2 deletions crates/argmin/src/solver/linesearch/backtracking.rs
Original file line number Diff line number Diff line change
@@ -640,7 +640,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

@@ -689,7 +694,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

7 changes: 6 additions & 1 deletion crates/argmin/src/tests.rs
Original file line number Diff line number Diff line change
@@ -161,7 +161,12 @@ fn test_lbfgs_func_count() {
let linesearch = MoreThuenteLineSearch::new();
let solver = LBFGS::new(linesearch, 10);
let res = Executor::new(cost.clone(), solver)
.configure(|config| config.param(cost.param_init.clone()).max_iters(100))
.configure(|config| {
config
.param(cost.param_init.clone())
.max_iters(100)
.counting(true)
})
.run()
.unwrap();

Loading