diff --git a/Cargo.toml b/Cargo.toml index 0564cc9e..14739721 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "examples/knapsack", "validator", ] +resolver = "2" [workspace.dependencies] anyhow = { version = "1.0"} @@ -22,6 +23,7 @@ regex = { version = "1" } tracing = { version = "0.1", features = ["release_max_level_debug"] } tracing-subscriber = "0.3" itertools = { version = "0.11.0" } +rand = { version = "0.8.5", features = ["small_rng"] } [profile.dev] opt-level = 0 diff --git a/solver/Cargo.toml b/solver/Cargo.toml index 8890e123..8b8eef2a 100644 --- a/solver/Cargo.toml +++ b/solver/Cargo.toml @@ -29,6 +29,7 @@ smallvec = "1.4.2" num-integer = { default-features = false, version = "0.1.44" } tracing = { workspace = true } lru = "0.10.0" +rand = { workspace = true } [dev-dependencies] rand = "0.8" diff --git a/solver/src/backtrack/queues.rs b/solver/src/backtrack/queues.rs index 987b5e91..8832978d 100644 --- a/solver/src/backtrack/queues.rs +++ b/solver/src/backtrack/queues.rs @@ -214,6 +214,7 @@ impl ObsTrail { ObsTrailCursor { next_read: EventIndex::from(0u32), last_backtrack: None, + pristine: true, _phantom: Default::default(), } } @@ -420,6 +421,7 @@ pub struct TrailEvent<'a, V> { pub struct ObsTrailCursor { next_read: EventIndex, last_backtrack: Option, + pristine: bool, _phantom: PhantomData, } @@ -437,10 +439,19 @@ impl ObsTrailCursor { ObsTrailCursor { next_read: EventIndex::from(0u32), last_backtrack: None, + pristine: true, _phantom: Default::default(), } } + pub fn is_pristine(&self) -> bool { + self.pristine + } + + pub fn mark_used(&mut self) { + self.pristine = false + } + // TODO: check correctness if more than one backtrack occurred between two synchronisations fn sync_backtrack(&mut self, queue: &ObsTrail) { if let Some(x) = &queue.last_backtrack { @@ -468,6 +479,7 @@ impl ObsTrailCursor { } pub fn pop<'q>(&mut self, queue: &'q ObsTrail) -> Option<&'q V> { + self.mark_used(); self.sync_backtrack(queue); let next = self.next_read; @@ -480,6 +492,7 @@ impl ObsTrailCursor { } pub fn move_to_end(&mut self, queue: &ObsTrail) { + self.mark_used(); self.sync_backtrack(queue); self.next_read = queue.next_slot(); } diff --git a/solver/src/core/state/int_domains.rs b/solver/src/core/state/int_domains.rs index 0246e2f4..b5f7e2c9 100644 --- a/solver/src/core/state/int_domains.rs +++ b/solver/src/core/state/int_domains.rs @@ -102,10 +102,12 @@ impl IntDomains { new_value: new, previous: current, }; + // println!("UPDATE: {lit:?} {cause:?}"); self.events.push(event); // update occurred and is consistent Ok(true) } else { + // println!("INVALID UPDATE: {lit:?} {cause:?}"); Err(InvalidUpdate(lit, cause)) } } diff --git a/solver/src/core/variable.rs b/solver/src/core/variable.rs index fcb53a0f..dac97831 100644 --- a/solver/src/core/variable.rs +++ b/solver/src/core/variable.rs @@ -5,13 +5,13 @@ use std::{fmt::Debug, hash::Hash}; /// Type representing an integer constant. pub type IntCst = i32; -/// Overflow tolerant min value for integer constants. -/// It is used as a default for the lower bound of integer variable domains -pub const INT_CST_MIN: IntCst = IntCst::MIN / 4; - /// Overflow tolerant max value for integer constants. /// It is used as a default for the upper bound of integer variable domains -pub const INT_CST_MAX: IntCst = IntCst::MAX / 4; +pub const INT_CST_MAX: IntCst = IntCst::MAX / 4 - 1; + +/// Overflow tolerant min value for integer constants. +/// It is used as a default for the lower bound of integer variable domains +pub const INT_CST_MIN: IntCst = -INT_CST_MAX; create_ref_type!(VarRef); diff --git a/solver/src/model/extensions/format.rs b/solver/src/model/extensions/format.rs index 017b776f..cd90d480 100644 --- a/solver/src/model/extensions/format.rs +++ b/solver/src/model/extensions/format.rs @@ -4,6 +4,7 @@ use crate::model::lang::{Atom, FAtom, IAtom, IVar, Kind, SAtom, Type}; use crate::model::symbols::{SymId, SymbolTable}; use crate::model::types::TypeId; use crate::model::ModelShape; +use crate::reif::{DifferenceExpression, ReifExpr}; use crate::utils::input::Sym; use crate::utils::Fmt; @@ -40,6 +41,10 @@ where fn get_symbol_table(&self) -> &SymbolTable { &self.get_shape().symbols } + + fn get_reified_expr(&self, lit: Lit) -> Option<&ReifExpr> { + self.get_shape().expressions.original(lit) + } } /// Wraps an atom into a custom object that can be formatted with the standard library `Display` @@ -78,6 +83,8 @@ fn format_impl_bool(ctx: &impl Shaped, b: Lit, f: &mut std::fmt write!(f, "true") } else if b == Lit::FALSE { write!(f, "false") + } else if let Some(reified) = ctx.get_reified_expr(b) { + format_reif(ctx, reified, f) } else if let Some(Type::Bool) = tpe { if b == t { format_impl_var(ctx, b.variable(), Kind::Bool, f) @@ -153,3 +160,40 @@ fn format_impl_var( write!(f, "{}{}", prefix, usize::from(v)) } } + +fn format_reif(ctx: &impl Shaped, e: &ReifExpr, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match e { + ReifExpr::Lit(l) => format_impl_bool(ctx, *l, f), + ReifExpr::MaxDiff(DifferenceExpression { a, b, ub }) => { + format_impl_var(ctx, *b, Kind::Sym, f)?; + write!(f, " - ")?; + format_impl_var(ctx, *a, Kind::Sym, f)?; + write!(f, " <= {ub}") + } + ReifExpr::Eq(v1, v2) => { + format_impl_var(ctx, *v1, Kind::Sym, f)?; + write!(f, " = ")?; + format_impl_var(ctx, *v2, Kind::Sym, f) + } + ReifExpr::Neq(v1, v2) => { + format_impl_var(ctx, *v1, Kind::Sym, f)?; + write!(f, " != ")?; + format_impl_var(ctx, *v2, Kind::Sym, f) + } + ReifExpr::EqVal(v1, v2) => { + format_impl_var(ctx, *v1, Kind::Sym, f)?; + let sym_id = SymId::from_u32(*v2 as u32); + let sym = ctx.get_symbol(sym_id); + write!(f, " = {sym}") + } + ReifExpr::NeqVal(v1, v2) => { + format_impl_var(ctx, *v1, Kind::Sym, f)?; + let sym_id = SymId::from_u32(*v2 as u32); + let sym = ctx.get_symbol(sym_id); + write!(f, " != {sym}") + } + ReifExpr::Or(_) => todo!(), + ReifExpr::And(_) => todo!(), + ReifExpr::Linear(_) => todo!(), + } +} diff --git a/solver/src/model/lang/reification.rs b/solver/src/model/lang/reification.rs index b042d8a4..359ecec5 100644 --- a/solver/src/model/lang/reification.rs +++ b/solver/src/model/lang/reification.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; pub struct Reification { /// Associates each canonical atom to a single literal. map: HashMap, + inv: HashMap, } impl Reification { @@ -25,7 +26,13 @@ impl Reification { pub fn intern_as(&mut self, e: ReifExpr, lit: Lit) { assert!(!self.map.contains_key(&e)); self.map.insert(e.clone(), lit); - self.map.insert(!e, !lit); + self.map.insert(!e.clone(), !lit); + self.inv.insert(lit, e.clone()); + self.inv.insert(!lit, !e.clone()); + } + + pub fn original(&self, lit: Lit) -> Option<&ReifExpr> { + self.inv.get(&lit) } } diff --git a/solver/src/reasoners/eq/domain.rs b/solver/src/reasoners/eq/domain.rs index 23f92612..9dbc4cba 100644 --- a/solver/src/reasoners/eq/domain.rs +++ b/solver/src/reasoners/eq/domain.rs @@ -35,15 +35,24 @@ impl Domain { self.value_literals.push(lit); } - pub fn get(&self, value: IntCst) -> Lit { - debug_assert!(self.bounds().contains(&value)); - self.value_literals[(value - self.first_value) as usize] + pub fn get(&self, value: IntCst) -> Option { + if !self.bounds().contains(&value) { + None + } else { + Some(self.value_literals[(value - self.first_value) as usize]) + } } fn values(&self, first: IntCst, last: IntCst) -> &[Lit] { - let first = (first - self.first_value) as usize; - let last = (last - self.first_value) as usize; - &self.value_literals[first..=last] + let first = (first as i64 - self.first_value as i64).max(0) as usize; + if let Ok(last) = usize::try_from(last as i64 - self.first_value as i64) { + let last = last.min(self.value_literals.len() - 1); + &self.value_literals[first..=last] + } else { + // last is before the start of the slice + // return empty slice + &self.value_literals[0..0] + } } } @@ -75,7 +84,7 @@ impl Domains { self.neq_watches.watches_on(l) } - pub fn value(&self, v: SignedVar, value: IntCst) -> Lit { + pub fn value(&self, v: SignedVar, value: IntCst) -> Option { let dom = &self.domains[&v.variable()]; if v.is_plus() { dom.get(value) diff --git a/solver/src/reasoners/eq/mod.rs b/solver/src/reasoners/eq/mod.rs index bb1592c7..7fb478f0 100644 --- a/solver/src/reasoners/eq/mod.rs +++ b/solver/src/reasoners/eq/mod.rs @@ -3,7 +3,7 @@ mod domain; use crate::backtrack::{Backtrack, DecLvl, EventIndex, ObsTrailCursor, Trail}; use crate::core::literals::Watches; use crate::core::state::{Cause, Domains, Explanation, InvalidUpdate}; -use crate::core::{IntCst, Lit, SignedVar, UpperBound, VarRef}; +use crate::core::{IntCst, Lit, SignedVar, UpperBound, VarRef, INT_CST_MAX}; use crate::model::{Label, Model}; use crate::reasoners::{Contradiction, ReasonerId, Theory}; use crate::reif::ReifExpr; @@ -425,18 +425,20 @@ impl EqTheory { domains.set_ub(var, value - 1, cause)?; } } + if self.graph.domains.has_domain(v.variable()) { for &invalid in self.graph.domains.values(v, new_ub + 1, previous_ub) { let cause = if v.is_plus() { Cause::inference(ReasonerId::Eq, InferenceCause::DomUpper) } else { + // dbg!(invalid, v, new_ub + 1, previous_ub); Cause::inference(ReasonerId::Eq, InferenceCause::DomLower) }; domains.set(!invalid, cause)?; } let mut updated_ub = new_ub; - loop { - let l = self.graph.domains.value(v, updated_ub); + + while let Some(l) = self.graph.domains.value(v, updated_ub) { if domains.entails(!l) { updated_ub -= 1; } else { @@ -450,8 +452,10 @@ impl EqTheory { let v = v.variable(); if domains.lb(v) == domains.ub(v) { let cause = Cause::inference(ReasonerId::Eq, InferenceCause::DomSingleton); - let l = self.graph.domains.value(SignedVar::plus(v), domains.lb(v)); - domains.set(l, cause)?; + if let Some(l) = self.graph.domains.value(SignedVar::plus(v), domains.ub(v)) { + // dbg!(l, v, domains.lb(v)); + domains.set(l, cause)?; + } } } Ok(()) @@ -486,6 +490,21 @@ impl Theory for EqTheory { } fn propagate(&mut self, domains: &mut Domains) -> Result<(), Contradiction> { + if self.cursor.is_pristine() { + // self.cursor.move_to_end(domains.trail()); + let vars = domains + .variables() + .flat_map(|v| [SignedVar::plus(v), SignedVar::minus(v)]) + .collect_vec(); + // + for v in vars { + let ub = domains.get_bound(v); + let new_lit = Lit::from_parts(v, ub); + self.propagate_edge_event(new_lit, domains)?; + self.propagate_domain_event(v, ub.as_int(), INT_CST_MAX, domains)?; + } + } + let mut cursor_copy = self.cursor.clone(); loop { let mut new_event_treated = false; @@ -493,7 +512,10 @@ impl Theory for EqTheory { while let Some(ev) = self.cursor.pop(domains.trail()) { if let Some(inference) = ev.cause.as_external_inference() { if inference.writer == self.identity() { - continue; // already handled during propagation + let cause = InferenceCause::from(inference.payload); + if let InferenceCause::EdgePropagation(_) = cause { + continue; // already handled during propagation + } } }; @@ -519,7 +541,7 @@ impl Theory for EqTheory { Ok(()) } - fn explain(&mut self, _: Lit, context: u32, domains: &Domains, out_explanation: &mut Explanation) { + fn explain(&mut self, l: Lit, context: u32, domains: &Domains, out_explanation: &mut Explanation) { let cause = InferenceCause::from(context); match cause { InferenceCause::EdgePropagation(event_index) => { @@ -542,12 +564,13 @@ impl Theory for EqTheory { push_causes(x, y); push_causes(y, z); } - _ => todo!(), - // InferenceCause::DomUpper => {} + x => { + dbg!(x, l); + todo!() + } // InferenceCause::DomUpper => {} // InferenceCause::DomLower => {} - // InferenceCause::DomNeq => {} - // InferenceCause::DomEq => {} - // InferenceCause::DomSingleton => {} + InferenceCause::DomNeq => {} // InferenceCause::DomEq => {} + // InferenceCause::DomSingleton => {} } } @@ -620,12 +643,15 @@ mod tests { use crate::model::lang::expr::eq; use crate::model::symbols::SymbolTable; use crate::model::types::TypeHierarchy; - use crate::model::Model; + use crate::model::{Label, Model}; use crate::reasoners::eq::{EqTheory, InferenceCause, Node, Pair, ReifyEq}; use crate::reasoners::{Contradiction, Theory}; + use crate::solver::search::random::RandomChoice; use crate::solver::Solver; use crate::utils::input::Sym; use itertools::Itertools; + use rand::prelude::SmallRng; + use rand::{Rng, SeedableRng}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -757,6 +783,8 @@ mod tests { let bc = theory.add_edge(b, c, eqs); let ac = theory.add_edge(a, c, eqs); + theory.propagate(domains).unwrap(); + domains.save_state(); theory.save_state(); domains.set(ab, Cause::Decision).expect("Invalid decision"); @@ -818,20 +846,93 @@ mod tests { #[test] fn test_model() { - let symbols = SymbolTable::from(vec![("obj", vec!["a", "b", "c"])]); + let symbols = SymbolTable::from(vec![("obj", vec!["alice", "bob", "chloe"])]); let symbols = Arc::new(symbols); let obj = symbols.types.id_of("obj").unwrap(); let mut model: Model = Model::new_with_symbols(symbols.clone()); - let x = model.new_sym_var(obj, "X"); - let y = model.new_sym_var(obj, "Y"); - let _z = model.new_sym_var(obj, "Z"); + let vars = ["V", "W", "X", "Y", "Z"] + .map(|var_name| model.new_sym_var(obj, var_name)) + .iter() + .copied() + .collect_vec(); + + for (xi, x) in vars.iter().copied().enumerate() { + for &y in &vars[xi..] { + model.reify(eq(x, y)); + } + } + + random_solves(&model, 10, Some(true)); + } + + fn random_solves(model: &Model, num_solves: u64, mut expected_result: Option) { + for seed in 0..num_solves { + let model = model.clone(); + let solver = &mut Solver::new(model); + solver.set_brancher(RandomChoice::new(seed)); + let solution = solver.solve().unwrap().is_some(); + if let Some(expected_sat) = expected_result { + assert_eq!(solution, expected_sat) + } + // ensure that the next run has the same output + expected_result = Some(solution) + } + } + + fn random_model(seed: u64) -> Model { + let mut rng = SmallRng::seed_from_u64(seed); + let objects = vec!["alice", "bob", "chloe", "donald", "elon"]; + let num_objects = rng.gen_range(1..5); + let objects = objects[0..num_objects].to_vec(); + let symbols = SymbolTable::from(vec![("obj", objects.clone())]); + let symbols = Arc::new(symbols); + + let obj = symbols.types.id_of("obj").unwrap(); - let _xy = model.reify(eq(x, y)); + let mut model: Model = Model::new_with_symbols(symbols.clone()); - let solver = &mut Solver::new(model); - solver.solve().unwrap(); + let num_scopes = rng.gen_range(0..3); + let scopes = (0..=num_scopes) + .into_iter() + .map(|i| { + if i == 0 { + Lit::TRUE + } else { + model.new_presence_variable(Lit::TRUE, format!("scope_{i}")).true_lit() + } + }) + .collect_vec(); + + let num_vars = rng.gen_range(0..10); + println!("Problem num_scopes: {num_scopes}, num_vars: {num_vars} num_values: {num_objects}"); + + let mut vars = Vec::with_capacity(num_vars); + for i in 0..num_vars { + let scope_id = rng.gen_range(0..scopes.len()); + let scope = scopes[scope_id]; + let var_name = format!("x{i}"); + println!(" {var_name} [{scope_id}] in {:?}", &objects); + let var = model.new_optional_sym_var(obj, scope, var_name); + vars.push(var) + } + + for (xi, x) in vars.iter().copied().enumerate() { + for &y in &vars[xi..] { + model.reify(eq(x, y)); + } + } + + model + } + + #[test] + fn random_problems() { + for seed in 0..100 { + let model = random_model(seed); + random_solves(&model, 30, Some(true)); + } } #[test] diff --git a/solver/src/solver/search.rs b/solver/src/solver/search.rs index 252348d1..42883339 100644 --- a/solver/src/solver/search.rs +++ b/solver/src/solver/search.rs @@ -2,6 +2,7 @@ pub mod activity; pub mod combinators; pub mod conflicts; pub mod lexical; +pub mod random; use crate::backtrack::Backtrack; use crate::core::state::{Conflict, Explainer}; diff --git a/solver/src/solver/search/random.rs b/solver/src/solver/search/random.rs new file mode 100644 index 00000000..cb596243 --- /dev/null +++ b/solver/src/solver/search/random.rs @@ -0,0 +1,76 @@ +use crate::backtrack::{Backtrack, DecLvl, DecisionLevelTracker}; +use crate::core::Lit; +use crate::model::extensions::AssignmentExt; +use crate::model::Model; +use crate::solver::search::{Decision, SearchControl}; +use crate::solver::stats::Stats; +use itertools::Itertools; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +/// Assigns all values in lexical order to their minimal value. +/// Essentially intended to finish the search once all high-priority variables have been set. +#[derive(Clone)] +pub struct RandomChoice { + rng: SmallRng, + lvl: DecisionLevelTracker, +} + +impl RandomChoice { + pub fn new(seed: u64) -> Self { + RandomChoice { + rng: SmallRng::seed_from_u64(seed), + lvl: Default::default(), + } + } +} + +impl Backtrack for RandomChoice { + fn save_state(&mut self) -> DecLvl { + self.lvl.save_state() + } + + fn num_saved(&self) -> u32 { + self.lvl.num_saved() + } + + fn restore_last(&mut self) { + self.lvl.restore_last() + } +} + +impl SearchControl for RandomChoice { + fn next_decision(&mut self, _stats: &Stats, model: &Model) -> Option { + // set the first domain value of the first unset variable + let variables = model + .state + .variables() + .filter(|v| { + if model.state.present(*v) == Some(true) { + let dom = model.var_domain(*v); + !dom.is_bound() + } else { + false + } + }) + .collect_vec(); + if variables.is_empty() { + return None; + } + let var_id = self.rng.gen_range(0..variables.len()); + let var = variables[var_id]; + let (lb, ub) = model.state.bounds(var); + let upper: bool = self.rng.gen(); + if upper { + let val = self.rng.gen_range(lb..ub); + Some(Decision::SetLiteral(Lit::leq(var, val))) + } else { + let val = self.rng.gen_range((lb + 1)..=ub); + Some(Decision::SetLiteral(Lit::geq(var, val))) + } + } + + fn clone_to_box(&self) -> Box + Send> { + Box::new(self.clone()) + } +}