Skip to content

Commit

Permalink
chore(solver): Refactor to avoid using BoundValue
Browse files Browse the repository at this point in the history
  • Loading branch information
arbimo committed Nov 14, 2024
1 parent 2f62b1e commit fd03e76
Show file tree
Hide file tree
Showing 20 changed files with 258 additions and 258 deletions.
17 changes: 13 additions & 4 deletions planning/planning/src/chronicles/concrete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::Arc;

use crate::chronicles::constraints::Constraint;
use crate::chronicles::Fluent;
use aries::core::{IntCst, Lit, VarRef};
use aries::core::{IntCst, Lit, SignedVar, VarRef};
use aries::model::lang::linear::{LinearSum, LinearTerm};
use aries::model::lang::*;

Expand Down Expand Up @@ -55,8 +55,17 @@ pub trait Substitution {
}

fn sub_lit(&self, b: Lit) -> Lit {
let (var, rel, val) = b.unpack();
Lit::new(self.sub_var(var), rel, val)
let svar = b.svar();
// substitute the variable
let new_var = self.sub_var(svar.variable());
// reapply the sign
let new_svar = if svar.is_plus() {
SignedVar::plus(new_var)
} else {
SignedVar::minus(new_var)
};
// reconstruct the literal
Lit::leq(new_svar, b.ub_value())
}

fn sub_linear_term(&self, term: &LinearTerm) -> LinearTerm {
Expand Down Expand Up @@ -183,7 +192,7 @@ impl Sub {
pub fn add_bool_expr_unification(&mut self, param: Lit, instance: Lit) -> Result<(), InvalidSubstitution> {
if param == instance {
Ok(())
} else if param.relation() == instance.relation() && param.value() == instance.value() {
} else if param.relation() == instance.relation() && param.ub_value() == instance.ub_value() {
self.add_untyped(param.variable(), instance.variable())
} else {
Err(InvalidSubstitution::IncompatibleStructures(
Expand Down
100 changes: 50 additions & 50 deletions solver/src/core/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ use std::cmp::Ordering;
/// - the bound `x > 0` represent the true literal (`X` takes the value `true`)
/// - the bound `x <= 0` represents the false literal (`X` takes the value `false`)
///
/// The struct is opaque as it is internal representation is optimized to allow more efficient usage.
/// To access individual fields the methods `variable()`, `relation()` and `value()` can be used.
/// The `unpack()` method extract all fields into a tuple.
///
/// ```
/// use aries::core::*;
Expand All @@ -23,20 +20,33 @@ use std::cmp::Ordering;
/// let x_is_false: Lit = !x_is_true;
/// let y = state.new_var(0, 10);
/// let y_geq_5 = Lit::geq(y, 5);
/// ```
///
/// # Representation
///
/// Internally, a literal is represented as an upper bound on a signed variable.
///
/// // the `<=` is internally converted into a `<`
/// assert_eq!(y_geq_5.variable(), y);
/// assert_eq!(y_geq_5.relation(), Relation::Gt);
/// assert_eq!(y_geq_5.value(), 4);
/// assert_eq!(y_geq_5.unpack(), (y, Relation::Gt, 4));
/// - var <= 5 -> var <= 5
/// - var < 5 -> var <= 4
/// - var >= 3 -> -var <= -3
/// - var > 3 -> -var <= -4
/// ```
/// use aries::core::*;
/// use aries::core::state::IntDomains;
/// let mut state = IntDomains::new();
/// let x = state.new_var(0, 1);
/// assert_eq!(x.leq(5), SignedVar::plus(x).leq(5));
/// assert_eq!(Lit::lt(x, 5), SignedVar::plus(x).leq(4));
/// assert_eq!(Lit::geq(x, 3), SignedVar::minus(x).leq(-3));
/// assert_eq!(Lit::gt(x, 3), SignedVar::minus(x).leq(-4));
/// ```
///
/// # Ordering
///
/// `Lit` defines a very specific order, which is equivalent to sorting the result of the `unpack()` method.
/// The different fields are compared in the following order to define the ordering:
/// - variable
/// - relation
/// - sign of the variable
/// - value
///
/// As a result, ordering a vector of `Lit`s will group them by variable, then among literals on the same variable by relation.
Expand All @@ -56,7 +66,7 @@ pub struct Lit {
svar: SignedVar,
/// Upper bound of the signed variable.
/// This design allows to test entailment without testing the relation of the Bound
upper_bound: UpperBound,
upper_bound: IntCst,
}

#[derive(Ord, PartialOrd, Eq, PartialEq, Debug, Copy, Clone)]
Expand All @@ -77,30 +87,12 @@ impl std::fmt::Display for Relation {
impl Lit {
/// A literal that is always true. It is defined by stating that the special variable [VarRef::ZERO] is
/// lesser than or equal to 0, which is always true.
pub const TRUE: Lit = Lit::new(VarRef::ZERO, Relation::Leq, 0);
pub const TRUE: Lit = Lit::new(SignedVar::plus(VarRef::ZERO), 0);
/// A literal that is always false. It is defined as the negation of [Lit::TRUE].
pub const FALSE: Lit = Lit::TRUE.not();

#[inline]
pub const fn from_parts(var_bound: SignedVar, value: UpperBound) -> Self {
Lit {
svar: var_bound,
upper_bound: value,
}
}

#[inline]
pub const fn new(variable: VarRef, relation: Relation, value: IntCst) -> Self {
match relation {
Relation::Leq => Lit {
svar: SignedVar::plus(variable),
upper_bound: UpperBound::ub(value),
},
Relation::Gt => Lit {
svar: SignedVar::minus(variable),
upper_bound: UpperBound::lb(value + 1),
},
}
pub const fn new(svar: SignedVar, upper_bound: IntCst) -> Lit {
Lit { svar, upper_bound }
}

#[inline]
Expand All @@ -117,11 +109,12 @@ impl Lit {
}
}

#[inline]
pub const fn value(self) -> IntCst {
match self.relation() {
Relation::Leq => self.upper_bound.as_int(),
Relation::Gt => -self.upper_bound.as_int() - 1, // TODO: this appear misleading
pub fn unpack(self) -> (VarRef, Relation, IntCst) {
if self.svar.is_plus() {
(self.svar.variable(), Relation::Leq, self.upper_bound)
} else {
// -var <= ub <=> var >= -ub <=> var > -ub -1
(self.svar.variable(), Relation::Gt, -self.upper_bound - 1)
}
}

Expand All @@ -131,13 +124,16 @@ impl Lit {
}

#[inline]
pub const fn bound_value(self) -> UpperBound {
pub const fn ub_value(self) -> IntCst {
self.upper_bound
}

#[inline]
pub fn leq(var: impl Into<SignedVar>, val: IntCst) -> Lit {
Lit::from_parts(var.into(), UpperBound::ub(val))
Lit {
svar: var.into(),
upper_bound: val,
}
}
#[inline]
pub fn lt(var: impl Into<SignedVar>, val: IntCst) -> Lit {
Expand Down Expand Up @@ -168,7 +164,7 @@ impl Lit {
// !(x <= d) <=> x > d <=> x >= d+1 <= -x <= -d -1
Lit {
svar: self.svar.neg(),
upper_bound: UpperBound::ub(-self.upper_bound.as_int() - 1),
upper_bound: -self.upper_bound - 1,
}
}

Expand All @@ -187,11 +183,7 @@ impl Lit {
/// ```
#[inline]
pub fn entails(self, other: Lit) -> bool {
self.svar == other.svar && self.upper_bound.stronger(other.upper_bound)
}

pub fn unpack(self) -> (VarRef, Relation, IntCst) {
(self.variable(), self.relation(), self.value())
self.svar == other.svar && self.upper_bound <= other.upper_bound
}

/// An ordering that will group literals by (given from highest to lowest priority):
Expand Down Expand Up @@ -242,13 +234,21 @@ impl std::fmt::Debug for Lit {
Lit::TRUE => write!(f, "true"),
Lit::FALSE => write!(f, "false"),
_ => {
let (var, rel, val) = self.unpack();
if rel == Relation::Gt && val == 0 {
write!(f, "l{}", var.to_u32())
} else if rel == Relation::Leq && val == 0 {
write!(f, "!l{}", var.to_u32())
let var = self.svar().variable();
if self.svar().is_plus() {
let upper_bound = self.upper_bound;
if upper_bound == 0 {
write!(f, "!l{:?}", var.to_u32())
} else {
write!(f, "{var:?} <= {upper_bound}")
}
} else {
write!(f, "{var:?} {rel} {val}")
let lb = -self.upper_bound;
if lb == 1 {
write!(f, "l{:?}", var.to_u32())
} else {
write!(f, "{lb:?} <= {var:?}")
}
}
}
}
Expand Down
17 changes: 6 additions & 11 deletions solver/src/core/literals/disjunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,10 @@ impl Disjunction {
let l2 = self.literals[i + 1];
debug_assert!(l1 < l2, "clause is not sorted");
if l1.variable() == l2.variable() {
debug_assert_eq!(l1.relation(), Relation::Gt);
debug_assert_eq!(l2.relation(), Relation::Leq);
let x = l1.value();
let y = l2.value();
// we have the disjunction var > x || var <= y
// if y > x, all values of var satisfy one of the disjuncts
if y >= x {
debug_assert!(l1.svar().is_minus());
debug_assert!(l2.svar().is_plus());
if (!l1).entails(l2) || (!l2).entails(l1) {
// all values of var satisfy one of the disjuncts
return true;
}
}
Expand Down Expand Up @@ -170,7 +167,7 @@ impl DisjunctionBuilder {

pub fn push(&mut self, lit: Lit) {
let sv = lit.svar();
let ub = lit.bound_value().as_int();
let ub = lit.ub_value();
let new_ub = if let Some(prev) = self.upper_bounds.get(&sv) {
// (sv <= ub) || (sv <= prev) <=> (sv <= max(ub, prev))
ub.max(*prev)
Expand All @@ -181,9 +178,7 @@ impl DisjunctionBuilder {
}

pub fn literals(&self) -> impl Iterator<Item = Lit> + '_ {
self.upper_bounds
.iter()
.map(|(k, v)| Lit::from_parts(*k, UpperBound::ub(*v)))
self.upper_bounds.iter().map(|(k, v)| Lit::leq(*k, *v))
}
}

Expand Down
19 changes: 10 additions & 9 deletions solver/src/core/literals/lit_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use std::fmt::{Debug, Formatter};
/// ```
#[derive(Clone, Default)]
pub struct LitSet {
elements: HashMap<SignedVar, UpperBound>,
/// List of signed vars and their corresponding upper bound
elements: HashMap<SignedVar, IntCst>,
}

impl LitSet {
Expand All @@ -51,23 +52,23 @@ impl LitSet {
}

pub fn literals(&self) -> impl Iterator<Item = Lit> + '_ {
self.elements.iter().map(|(var, val)| Lit::from_parts(*var, *val))
self.elements.iter().map(|(var, val)| var.leq(*val))
}

pub fn contains(&self, elem: Lit) -> bool {
self.elements
.get(&elem.svar())
.map_or(false, |b| b.stronger(elem.bound_value()))
.map_or(false, |ub| *ub <= elem.ub_value())
}

/// Insert a literal `lit` into the set.
///
/// Note that all literals directly implied by `lit` are also implicitly inserted.
pub fn insert(&mut self, lit: Lit) {
#[allow(clippy::or_fun_call)]
let val = self.elements.entry(lit.svar()).or_insert(lit.bound_value());
if lit.bound_value().strictly_stronger(*val) {
*val = lit.bound_value()
let val = self.elements.entry(lit.svar()).or_insert(lit.ub_value());
if lit.ub_value() < *val {
*val = lit.ub_value()
}
}

Expand All @@ -87,19 +88,19 @@ impl LitSet {
///
pub fn remove(&mut self, rm: Lit, tautology: impl Fn(Lit) -> bool) {
debug_assert!(self.contains(rm));
let weaker = Lit::from_parts(rm.svar(), rm.bound_value() + BoundValueAdd::RELAXATION);
let weaker = rm.svar().leq(rm.ub_value() + 1);
if tautology(weaker) {
self.elements.remove(&rm.svar());
} else {
self.elements.insert(rm.svar(), weaker.bound_value());
self.elements.insert(rm.svar(), weaker.ub_value());
}
}
}

impl Debug for LitSet {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_set()
.entries(self.elements.iter().map(|(svar, ub)| svar.with_upper_bound(*ub)))
.entries(self.elements.iter().map(|(svar, ub)| svar.leq(*ub)))
.finish()
}
}
Expand Down
17 changes: 11 additions & 6 deletions solver/src/core/literals/watches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ impl<Watcher> WatchSet<Watcher> {
pub fn add_watch(&mut self, watcher: Watcher, literal: Lit) {
self.watches.push(Watch {
watcher,
guard: literal.bound_value(),
guard: literal.ub_value(),
})
}

pub fn len(&self) -> usize {
self.watches.len()
}

pub fn is_empty(&self) -> bool {
self.watches.is_empty()
}

pub fn clear(&mut self) {
self.watches.clear();
}
Expand All @@ -44,15 +48,15 @@ impl<Watcher> WatchSet<Watcher> {
{
self.watches
.iter()
.any(|w| w.watcher == watcher && literal.bound_value().stronger(w.guard))
.any(|w| w.watcher == watcher && literal.ub_value() <= w.guard)
}

pub fn watches_on(&self, literal: Lit) -> impl Iterator<Item = Watcher> + '_
where
Watcher: Copy,
{
self.watches.iter().filter_map(move |w| {
if literal.bound_value().stronger(w.guard) {
if literal.ub_value() <= w.guard {
Some(w.watcher)
} else {
None
Expand All @@ -67,7 +71,7 @@ impl<Watcher> WatchSet<Watcher> {
pub fn move_watches_to(&mut self, literal: Lit, out: &mut WatchSet<Watcher>) {
let mut i = 0;
while i < self.watches.len() {
if literal.bound_value().stronger(self.watches[i].guard) {
if literal.ub_value() <= self.watches[i].guard {
let w = self.watches.swap_remove(i);
out.watches.push(w);
} else {
Expand All @@ -86,11 +90,12 @@ impl<Watcher> Default for WatchSet<Watcher> {
#[derive(Copy, Clone)]
pub struct Watch<Watcher> {
pub watcher: Watcher,
guard: UpperBound,
/// upper bound
guard: IntCst,
}
impl<Watcher> Watch<Watcher> {
pub fn to_lit(&self, var_bound: SignedVar) -> Lit {
Lit::from_parts(var_bound, self.guard)
var_bound.leq(self.guard)
}
}

Expand Down
Loading

0 comments on commit fd03e76

Please sign in to comment.