Skip to content

Commit

Permalink
wip(stn): tests and exploitation of relevant focus
Browse files Browse the repository at this point in the history
  • Loading branch information
arbimo committed Nov 12, 2024
1 parent d9df1fd commit 742bc95
Showing 1 changed file with 235 additions and 59 deletions.
294 changes: 235 additions & 59 deletions solver/src/reasoners/stn/theory/distances.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,141 @@ impl Default for DijkstraState {
#[cfg(test)]
mod test {
use crate::core::IntCst;
use itertools::Itertools;
use rand::prelude::SeedableRng;
use rand::prelude::SmallRng;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::iter::once;

#[derive(Eq, PartialEq)]
struct Reverse<'a, G: Graph>(&'a G);

impl<'a, G: Graph> Graph for Reverse<'a, G> {
fn vertices(&self) -> impl Iterator<Item = V> + '_ {
self.0.vertices()
}

fn outgoing(&self, src: V) -> impl Iterator<Item = Edge> + '_ {
self.0.incoming(src).map(|e| Edge::new(e.tgt, e.src, e.weight))
}

fn incoming(&self, src: V) -> impl Iterator<Item = Edge> + '_ {
self.0.outgoing(src).map(|e| Edge::new(e.tgt, e.src, e.weight))
}
}

pub trait Graph {
fn vertices(&self) -> impl Iterator<Item = V> + '_;
fn outgoing(&self, src: V) -> impl Iterator<Item = Edge> + '_;
fn incoming(&self, src: V) -> impl Iterator<Item = Edge> + '_;

fn relevants(&self, new_edge: &Edge) -> Vec<(V, IntCst)> {
dbg!(new_edge);
let mut relevants = Vec::new();
let mut visited = HashSet::new();
let mut heap = BinaryHeap::new();

let mut best_label: HashMap<V, Label> = HashMap::new();

// order allows to override the label of the target edge if the edge is a self loop
let tgt_lbl = Label::new(new_edge.weight, true);
best_label.insert(new_edge.tgt, tgt_lbl);
heap.push((tgt_lbl, new_edge.tgt));
let src_lbl = Label::new(0, false);
best_label.insert(new_edge.src, src_lbl);
heap.push((src_lbl, new_edge.src));

// count of the number of unvisited relevants in the queue
let mut remaining_relevants: u32 = 1;

while let Some((lbl @ Label { dist, relevant }, curr)) = heap.pop() {
if visited.contains(&curr) {
// already treated, ignore
continue;
}
visited.insert(curr);
debug_assert_eq!(lbl, best_label[&curr]);
if relevant {
// there is a shortest path through new edge to v
relevants.push((curr, dist));
remaining_relevants -= 1;
}
for out in self.outgoing(curr) {
let lbl = Label::new(dist + out.weight, relevant);

if let Some(previous_label) = best_label.get(&out.tgt) {
if previous_label >= &lbl {
debug_assert!(previous_label.dist <= lbl.dist);
continue; // no improvement, ignore
}
if previous_label.relevant && !lbl.relevant {
remaining_relevants -= 1
} else if !previous_label.relevant && lbl.relevant {
remaining_relevants += 1;
}
} else if lbl.relevant {
remaining_relevants += 1;
}
best_label.insert(out.tgt, lbl);
heap.push((lbl, out.tgt));
}
if remaining_relevants == 0 {
// there is no hope of finding new relevants;
break;
}
}

relevants
}

fn potentially_updated_paths(&self, new_edge: &Edge) -> Vec<Edge>
where
Self: Sized,
{
let mut updated_paths = Vec::with_capacity(32);
let relevants_after = self.relevants(new_edge);
let reversed = Reverse(self);
let relevants_before = reversed.relevants(&new_edge.reverse());

for (end, cost_from_src) in relevants_after {
for (orig, cost_to_tgt) in relevants_before.iter().copied() {
updated_paths.push(Edge {
src: orig,
tgt: end,
weight: cost_to_tgt - new_edge.weight + cost_from_src,
})
}
}
updated_paths
}

fn ssp(&self, src: V, tgt: V) -> Option<IntCst> {
let mut visited = HashSet::new();
// this is a max heap, so we will store the negation of computed distances
let mut heap = BinaryHeap::new();

heap.push((-0, src));

while let Some((neg_dist, curr)) = heap.pop() {
if visited.contains(&curr) {
// already treated, ignore
continue;
}
visited.insert(curr);
if curr == tgt {
return Some(-neg_dist);
}
for out in self.outgoing(curr) {
let lbl = neg_dist - out.weight;
heap.push((lbl, out.tgt));
}
}
None
}
}

#[derive(Eq, PartialEq, Copy, Clone, Debug)]
struct Label {
dist: IntCst,
relevant: bool,
Expand Down Expand Up @@ -168,65 +299,60 @@ mod test {
type V = u32;
type L = IntCst;

#[derive(Debug, Copy, Clone)]
struct Edge {
src: V,
tgt: V,
label: L,
weight: L,
}

impl Edge {
pub fn new(src: V, tgt: V, label: L) -> Self {
Self { src, tgt, label }
}
}

fn succs(edges: &[Edge], src: V) -> impl Iterator<Item = &Edge> + '_ {
edges.iter().filter(move |e| e.src == src)
}

fn relevants(g: &[Edge], new_edge: &Edge) -> Vec<V> {
let mut relevants = Vec::new();
let mut visited = HashSet::new();
let mut heap = BinaryHeap::new();

heap.push((Label::new(0, false), new_edge.src));
heap.push((Label::new(new_edge.label, true), new_edge.tgt));

while let Some((Label { dist, relevant }, curr)) = heap.pop() {
if visited.contains(&curr) {
// already treated, ignore
continue;
Self {
src,
tgt,
weight: label,
}
visited.insert(curr);
if relevant {
// there is a shortest path through new edge to v
relevants.push(curr)
}
for out in succs(g, curr) {
let lbl = Label::new(dist + out.label, relevant);
heap.push((lbl, out.tgt));
}
pub fn reverse(self) -> Self {
Self {
src: self.tgt,
tgt: self.src,
weight: self.weight,
}
}

relevants
}

fn ssp(g: &[Edge], src: V, tgt: V) -> Option<IntCst> {
// this is a max heap, so we will store the negation of computed distances
let mut heap = BinaryHeap::new();

heap.push((-0, src));
impl Graph for &[Edge] {
fn vertices(&self) -> impl Iterator<Item = V> + '_ {
self.iter()
.flat_map(|e| once(e.src).chain(once(e.tgt)))
.sorted()
.unique()
}
fn outgoing(&self, v: V) -> impl Iterator<Item = Edge> + '_ {
self.iter().copied().filter(move |e| e.src == v)
}
fn incoming(&self, v: V) -> impl Iterator<Item = Edge> + '_ {
self.iter().copied().filter(move |e| e.tgt == v)
}
}

while let Some((neg_dist, curr)) = heap.pop() {
if curr == tgt {
return Some(-neg_dist);
}
for out in succs(g, curr) {
let lbl = neg_dist - out.label;
heap.push((lbl, out.tgt));
}
fn gen_graph(seed: u64) -> Vec<Edge> {
let mut graph = Vec::new();
let mut rng = SmallRng::seed_from_u64(seed);
let num_vertices = rng.gen_range(4..5);
let num_edges = rng.gen_range(2..=6);

for _ in 0..num_edges {
let src = rng.gen_range(0..num_vertices);
let tgt = rng.gen_range(0..num_vertices);
let weight = rng.gen_range(0..10);
let edge = Edge { src, tgt, weight };
graph.push(edge)
}
None

graph
}

#[test]
Expand All @@ -239,23 +365,73 @@ mod test {
Edge::new(2, 4, 1),
];

assert_eq!(ssp(g, 1, 2), Some(1));
assert_eq!(ssp(g, 1, 3), Some(4));
assert_eq!(ssp(g, 1, 4), Some(2));
assert_eq!(g.ssp(1, 2), Some(1));
assert_eq!(g.ssp(1, 3), Some(4));
assert_eq!(g.ssp(1, 4), Some(2));

let graphs = vec![g];
let graphs = (0..1000).map(gen_graph).collect_vec();

for graph in graphs {
let original_graph = &graph[1..];
let added_edge = &graph[0];
let final_graph = graph;
let updated = relevants(original_graph, added_edge);

for up in updated {
let previous = ssp(original_graph, added_edge.src, up);
let new = ssp(final_graph, added_edge.src, up).unwrap();
println!("{up}: {previous:?} -> {new}");
assert!(previous.is_none() || previous.unwrap() > new);
let final_graph = graph.as_slice();

let updated = original_graph.relevants(added_edge);
let updated: HashMap<V, IntCst> = updated.into_iter().collect();
println!("{:?}", final_graph);
dbg!("{:?}", updated.clone());

for other in final_graph.vertices() {
let previous = original_graph.ssp(added_edge.src, other);
let new = final_graph.ssp(added_edge.src, other);
let new_sp = match (previous, new) {
(Some(previous), Some(new)) => new < previous,
(None, Some(_new)) => true,
(Some(_), None) => panic!("A path disapeared ?"),
_ => false,
};
let present_in_updated = updated.contains_key(&other);
assert_eq!(new_sp, present_in_updated, "{:?} -> {:?}", added_edge.src, other);
if present_in_updated {
assert_eq!(
updated[&other],
new.unwrap(),
"The length of the shortest paths should be the same"
);
}
}
}

// assert_eq!(relevants(&g[1..=3], &g[0]), vec! {2});
// assert_eq!(relevants(&g[1..=4], &g[0]), vec! {2, 4});
}

#[test]
fn test_graph_updates() {
let graphs = (0..1000).map(gen_graph).collect_vec();

for graph in graphs {
let original_graph = &graph[1..];
let added_edge = &graph[0];
let final_graph = graph.as_slice();

let updated_paths = original_graph.potentially_updated_paths(added_edge);
let updated_paths: HashMap<(V, V), IntCst> =
updated_paths.into_iter().map(|e| ((e.src, e.tgt), e.weight)).collect();

for orig in final_graph.vertices() {
for dest in final_graph.vertices() {
let previous = original_graph.ssp(orig, dest);
let new = final_graph.ssp(orig, dest);
let new_sp = match (previous, new) {
(Some(previous), Some(new)) => new < previous,
(None, Some(_new)) => true,
(Some(_), None) => panic!("A path disapeared ?"),
_ => false,
};
let present_in_updated = updated_paths.contains_key(&(orig, dest));
assert!(!new_sp || present_in_updated); // new_sp => present_in_updated
}
}
}

Expand Down

0 comments on commit 742bc95

Please sign in to comment.