diff --git a/solver/src/reasoners/eq/mod.rs b/solver/src/reasoners/eq/mod.rs index 86506206..730fc05d 100644 --- a/solver/src/reasoners/eq/mod.rs +++ b/solver/src/reasoners/eq/mod.rs @@ -2,12 +2,13 @@ use crate::backtrack::{Backtrack, DecLvl, EventIndex, ObsTrailCursor, Trail}; use crate::collections::ref_store::RefMap; use crate::core::literals::Watches; use crate::core::state::{Cause, Domains, Explanation, InvalidUpdate}; -use crate::core::{Lit, VarRef}; +use crate::core::{Lit, SignedVar, VarRef}; use crate::model::{Label, Model}; use crate::reasoners::{Contradiction, ReasonerId, Theory}; use crate::reif::ReifExpr; use itertools::Itertools; use std::collections::{HashMap, HashSet}; +use std::ops::Shl; #[derive(Copy, Clone, Debug)] struct OutEdge { @@ -71,7 +72,34 @@ struct DirEdgeLabel { #[derive(Clone, Debug)] enum Event { - Propagation { x: VarRef, y: VarRef, z: VarRef }, + EdgePropagation { x: VarRef, y: VarRef, z: VarRef }, +} + +enum InferenceCause { + EdgePropagation(EventIndex), + UBPropagation { source: SignedVar }, +} + +impl From for u32 { + fn from(value: InferenceCause) -> Self { + match value { + InferenceCause::EdgePropagation(e) => 0u32 + (u32::from(e) << 1), + InferenceCause::UBPropagation { source } => 1u32 + (u32::from(source) << 1), + } + } +} + +impl From for InferenceCause { + fn from(value: u32) -> Self { + let kind = value & 0x1; + let payload = value >> 1; + match value & 0x1 { + 0 => InferenceCause::EdgePropagation(EventIndex::from(payload)), + 1 => InferenceCause::UBPropagation { + source: SignedVar::from(payload), + }, + } + } } #[derive(Clone, Default)] @@ -148,7 +176,7 @@ fn set_edge_label( Some(true) => Ok(false), _ => { // there might be a change, record event source to be able to explain it - let event = Event::Propagation { x, y, z }; + let event = Event::EdgePropagation { x, y, z }; let id = trail.push(event); let cause = Cause::inference(ReasonerId::Eq, id); domains.set(label, cause) @@ -312,6 +340,78 @@ impl EqTheory { } Ok(()) } + + pub fn propagate_ub_change(&mut self, v: SignedVar, domains: &mut Domains) -> Result<(), InvalidUpdate> { + let ub = domains.get_bound(v); + let cause = Cause::inference(ReasonerId::Eq, InferenceCause::UBPropagation { source: v }); + + for out in &self.graph.succs[v.variable()] { + if domains.entails(out.active) { + if domains.entails(out.label) { + let svar = if v.is_plus() { + SignedVar::plus(out.succ) + } else { + SignedVar::minus(out.succ) + }; + domains.set_bound(svar, ub, cause)?; + } + } + } + Ok(()) + } + + pub fn propagate_eq_domains_lr( + &mut self, + left: VarRef, + right: VarRef, + domains: &mut Domains, + ) -> Result<(), InvalidUpdate> { + debug_assert!(domains.entails(self.graph.active(left, right))); + debug_assert!(domains.entails(self.graph.label(left, right))); + + // enforce upper bound by capping inverse var + let left = SignedVar::plus(left); + let right = SignedVar::plus(right); + let cause = Cause::inference(ReasonerId::Eq, InferenceCause::UBPropagation { source: left }); + let ub = domains.get_bound(left); + domains.set_bound(right, ub, cause)?; + + // enforce lower bound by capping inverse var + let left = -left; + let right = -right; + let cause = Cause::inference(ReasonerId::Eq, InferenceCause::UBPropagation { source: left }); + let ub = domains.get_bound(left); + domains.set_bound(right, ub, cause)?; + + Ok(()) + } + + pub fn propagate_neq_domains_lr( + &mut self, + left: VarRef, + right: VarRef, + domains: &mut Domains, + ) -> Result<(), InvalidUpdate> { + debug_assert!(domains.entails(self.graph.active(left, right))); + debug_assert!(domains.entails(!self.graph.label(left, right))); + + // enforce upper bound by capping inverse var + let left = SignedVar::plus(left); + let right = SignedVar::plus(right); + let cause = Cause::inference(ReasonerId::Eq, InferenceCause::UBPropagation { source: left }); + let ub = domains.get_bound(left); + domains.set_bound(right, ub, cause)?; + + // enforce lower bound by capping inverse var + let left = -left; + let right = -right; + let cause = Cause::inference(ReasonerId::Eq, InferenceCause::UBPropagation { source: left }); + let ub = domains.get_bound(left); + domains.set_bound(right, ub, cause)?; + + Ok(()) + } + pub fn add_edge(&mut self, a: VarRef, b: VarRef, model: &mut impl ReifyEq) -> Lit { self.graph.add_node(a, model); self.graph.add_node(b, model); @@ -356,7 +456,7 @@ impl Theory for EqTheory { fn explain(&mut self, _: Lit, context: u32, domains: &Domains, out_explanation: &mut Explanation) { let event_index = EventIndex::from(context); let event = self.trail.get_event(event_index); - let &Event::Propagation { x, y, z } = event; + let &Event::EdgePropagation { x, y, z } = event; let mut push_causes = |a, b| { let ab_act = self.graph.active(a, b);