Skip to content

Commit

Permalink
feat(sat): Add alternating mode (stable/focused) to sat solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
arbimo committed Oct 1, 2024
1 parent c74b19a commit 23768fb
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 14 deletions.
39 changes: 32 additions & 7 deletions examples/sat/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use anyhow::*;
use aries::core::Lit;
use aries::model::lang::expr::or;
use aries::solver::parallel::{ParSolver, SolverResult};
use aries::solver::search::combinators::WithGeomRestart;
use aries::solver::search::combinators::{RoundRobin, WithGeomRestart};
use aries::solver::search::conflicts::{ConflictBasedBrancher, Params};
use aries::solver::search::SearchControl;
use aries::solver::Solver;
use std::collections::HashMap;
use std::fs::File;
Expand Down Expand Up @@ -106,18 +107,42 @@ fn solve_multi_threads(model: Model, opt: &Opt, deadline: Option<Instant>) -> Re
let search_params: Vec<_> = opt.search.split(",").collect();
let num_threads = search_params.len();

let mut par_solver = ParSolver::new(solver, num_threads, |id, solver| {
let choices = choices.clone();
let conflict_params = |conf: &str| {
let mut params = Params::default();
for opt in search_params[id].split(":") {
for opt in conf.split(":") {
let handled = params.configure(opt);
if !handled {
panic!("UNSUPPORTED OPTION: {opt}")
}
}
let brancher = Box::new(ConflictBasedBrancher::with(choices, params));
let brancher = WithGeomRestart::new(100, 1.2, brancher);
solver.set_brancher(brancher);
params
};

let mut par_solver = ParSolver::new(solver, num_threads, |id, solver| {
let search_params: Vec<_> = search_params[id].split("/").collect();
let stable_params = if search_params.len() > 0 {
search_params[0]
} else {
"+lrb:+p+l:-neg"
};
let focused_params = if search_params.len() > 1 {
search_params[1]
} else {
"+lrb:+p:+neg"
};
let choices = choices.clone();

let stable_params = conflict_params(stable_params);
let stable_brancher = Box::new(ConflictBasedBrancher::with(choices.clone(), stable_params));
let stable_brancher = WithGeomRestart::new(5000, 1.2, stable_brancher).clone_to_box();

let focused_params = conflict_params(focused_params);
let focused_brancher = Box::new(ConflictBasedBrancher::with(choices, focused_params));
let focused_brancher = WithGeomRestart::new(400, 1.0, focused_brancher).clone_to_box();

let round_robin = RoundRobin::new(10_000, 1.1, vec![stable_brancher, focused_brancher]);

solver.set_brancher(round_robin);
});

match par_solver.solve(deadline) {
Expand Down
99 changes: 99 additions & 0 deletions solver/src/solver/search/combinators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::model::extensions::SavedAssignment;
use crate::model::Model;
use crate::solver::search::{Brancher, Decision, SearchControl};
use crate::solver::stats::Stats;
use itertools::Itertools;
use std::sync::Arc;

/// A trait that provides extension methods for branchers
Expand Down Expand Up @@ -275,3 +276,101 @@ impl<L: 'static> SearchControl<L> for WithGeomRestart<L> {
})
}
}

/// A solver that alternates between the given strategies in a round-robin fashion.
pub struct RoundRobin<L> {
/// Number of conflicts before switching to the next
num_conflicts_per_period: u64,
/// Factor by witch to multiply the number of conflicts/period after a switch
increase_factor: f64,
num_conflicts_since_switch: u64,
/// Index of the current brancher.
current_idx: usize,
branchers: Vec<Brancher<L>>,
}

impl<L> RoundRobin<L> {
pub fn new(num_conflicts_per_period: u64, increase_factor: f64, branchers: Vec<Brancher<L>>) -> Self {
RoundRobin {
num_conflicts_per_period,
increase_factor,
num_conflicts_since_switch: 0,
current_idx: 0,
branchers,
}
}
fn current(&self) -> &Brancher<L> {
&self.branchers[self.current_idx]
}
fn current_mut(&mut self) -> &mut Brancher<L> {
&mut self.branchers[self.current_idx]
}
}

impl<L> Backtrack for RoundRobin<L> {
fn save_state(&mut self) -> DecLvl {
self.current_mut().save_state()
}

fn num_saved(&self) -> u32 {
self.current().num_saved()
}

fn restore_last(&mut self) {
self.current_mut().restore_last();

// we are at the ROOT, check if we should switch to the next brancher
if self.num_saved() == 0 && self.num_conflicts_since_switch >= self.num_conflicts_per_period {
self.current_idx = (self.current_idx + 1) % self.branchers.len();
self.num_conflicts_since_switch = 0;
self.num_conflicts_per_period = (self.num_conflicts_per_period as f64 * self.increase_factor) as u64;
}
}
}

impl<L: 'static> SearchControl<L> for RoundRobin<L> {
fn next_decision(&mut self, stats: &Stats, model: &Model<L>) -> Option<Decision> {
self.current_mut().next_decision(stats, model)
}

fn import_vars(&mut self, model: &Model<L>) {
self.current_mut().import_vars(model)
}

fn new_assignment_found(&mut self, objective_value: IntCst, assignment: Arc<SavedAssignment>) {
self.current_mut().new_assignment_found(objective_value, assignment)
}

fn pre_save_state(&mut self, _model: &Model<L>) {
self.current_mut().pre_save_state(_model);
}

fn pre_conflict_analysis(&mut self, _model: &Model<L>) {
self.current_mut().pre_conflict_analysis(_model);
}

fn conflict(
&mut self,
clause: &Conflict,
model: &Model<L>,
explainer: &mut dyn Explainer,
backtrack_level: DecLvl,
) {
self.num_conflicts_since_switch += 1;
self.current_mut().conflict(clause, model, explainer, backtrack_level)
}

fn asserted_after_conflict(&mut self, lit: Lit, model: &Model<L>) {
self.current_mut().asserted_after_conflict(lit, model)
}

fn clone_to_box(&self) -> Brancher<L> {
Box::new(Self {
num_conflicts_per_period: self.num_conflicts_per_period,
increase_factor: self.increase_factor,
num_conflicts_since_switch: self.num_conflicts_since_switch,
current_idx: self.current_idx,
branchers: self.branchers.iter().map(|b| b.clone_to_box()).collect_vec(),
})
}
}
15 changes: 8 additions & 7 deletions solver/src/solver/search/conflicts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,14 @@ impl<Var> SearchControl<Var> for ConflictBasedBrancher {
}
}

let lbd = lbd(clause, &model.state);
let impact = match self.params.impact_measure {
ImpactMeasure::Unit => 1.0,
ImpactMeasure::LBD => 1f64 + 1f64 / lbd as f64,
ImpactMeasure::LogLBD => 1f64 + 1f64 / (1f64 + (lbd as f64).log2()),
ImpactMeasure::SearchSpaceReduction => 1f64 / 2f64.powf(lbd as f64),
};

// we have identified all culprits, update the heuristic information (depending on the heuristic used)
for culprit in culprits.literals() {
match self.params.heuristic {
Expand All @@ -704,13 +712,6 @@ impl<Var> SearchControl<Var> for ConflictBasedBrancher {
self.bump_activity(culprit, model);
}
Heuristic::LearningRate => {
let lbd = lbd(clause, &model.state);
let impact = match self.params.impact_measure {
ImpactMeasure::Unit => 1.0,
ImpactMeasure::LBD => 1f64 + 1f64 / lbd as f64,
ImpactMeasure::LogLBD => 1f64 + 1f64 / (1f64 + (lbd as f64).log2()),
ImpactMeasure::SearchSpaceReduction => 1f64 / 2f64.powf(lbd as f64),
};
// learning rate branching, record that the variable participated in thus conflict
// the variable's priority will be updated upon backtracking
let v = culprit.variable();
Expand Down

0 comments on commit 23768fb

Please sign in to comment.