Skip to content

Commit

Permalink
Separate extractor with timeout and without timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen committed Jan 1, 2024
1 parent 36f05d0 commit 8c1f9d7
Showing 1 changed file with 100 additions and 88 deletions.
188 changes: 100 additions & 88 deletions src/extract/ilp_cbc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ This extractor is simple so that it's easy to see that it's correct.
If the timeout is reached, it will return the result of the faster-greedy-dag extractor.
*/

// Without a timeout, some will take > 10 hours to finish.
const SOLVING_TIME_LIMIT_SECONDS: u64 = 10;

use super::*;
use coin_cbc::{Col, Model, Sense};
use indexmap::IndexSet;
Expand All @@ -18,108 +15,123 @@ struct ClassVars {
}

pub struct CbcExtractor;
pub struct CbcExtractorWithTimeout;

impl Extractor for CbcExtractorWithTimeout {
fn extract(egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
// Without a timeout, some will take > 10 hours to finish.
const SOLVING_TIME_LIMIT_SECONDS: u32 = 10;
return extract(egraph, roots, SOLVING_TIME_LIMIT_SECONDS);
}
}

impl Extractor for CbcExtractor {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
let mut model = Model::default();

model.set_parameter("seconds", &SOLVING_TIME_LIMIT_SECONDS.to_string());

let vars: IndexMap<ClassId, ClassVars> = egraph
.classes()
.values()
.map(|class| {
let cvars = ClassVars {
active: model.add_binary(),
nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
};
(class.id.clone(), cvars)
})
.collect();

for (class_id, class) in &vars {
// class active == some node active
// sum(for node_active in class) == class_active
let row = model.add_row();
model.set_row_equal(row, 0.0);
model.set_weight(row, class.active, -1.0);
for &node_active in &class.nodes {
model.set_weight(row, node_active, 1.0);
}
return extract(egraph, roots, std::u32::MAX);
}
}

fn extract(egraph: &EGraph, roots: &[ClassId], timeout_seconds: u32) -> ExtractionResult {
let mut model = Model::default();

let childrens_classes_var = |nid: NodeId| {
egraph[&nid]
.children
.iter()
.map(|n| egraph[n].eclass.clone())
.map(|n| vars[&n].active)
.collect::<IndexSet<_>>()
model.set_parameter("seconds", &timeout_seconds.to_string());

let vars: IndexMap<ClassId, ClassVars> = egraph
.classes()
.values()
.map(|class| {
let cvars = ClassVars {
active: model.add_binary(),
nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
};
(class.id.clone(), cvars)
})
.collect();

for (class_id, class) in &vars {
// class active == some node active
// sum(for node_active in class) == class_active
let row = model.add_row();
model.set_row_equal(row, 0.0);
model.set_weight(row, class.active, -1.0);
for &node_active in &class.nodes {
model.set_weight(row, node_active, 1.0);
}

for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
for child_active in childrens_classes_var(node_id.clone()) {
// node active implies child active, encoded as:
// node_active <= child_active
// node_active - child_active <= 0
let row = model.add_row();
model.set_row_upper(row, 0.0);
model.set_weight(row, node_active, 1.0);
model.set_weight(row, child_active, -1.0);
}
let childrens_classes_var = |nid: NodeId| {
egraph[&nid]
.children
.iter()
.map(|n| egraph[n].eclass.clone())
.map(|n| vars[&n].active)
.collect::<IndexSet<_>>()
};

for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
for child_active in childrens_classes_var(node_id.clone()) {
// node active implies child active, encoded as:
// node_active <= child_active
// node_active - child_active <= 0
let row = model.add_row();
model.set_row_upper(row, 0.0);
model.set_weight(row, node_active, 1.0);
model.set_weight(row, child_active, -1.0);
}
}
}

model.set_obj_sense(Sense::Minimize);
for class in egraph.classes().values() {
for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
let node = &egraph[node_id];
let node_cost = node.cost.into_inner();
assert!(node_cost >= 0.0);
model.set_obj_sense(Sense::Minimize);
for class in egraph.classes().values() {
for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
let node = &egraph[node_id];
let node_cost = node.cost.into_inner();
assert!(node_cost >= 0.0);

if node_cost != 0.0 {
model.set_obj_coeff(node_active, node_cost);
}
if node_cost != 0.0 {
model.set_obj_coeff(node_active, node_cost);
}
}
}

for root in roots {
model.set_col_lower(vars[root].active, 1.0);
}
for root in roots {
model.set_col_lower(vars[root].active, 1.0);
}

block_cycles(&mut model, &vars, &egraph);

let solution = model.solve();
log::info!(
"CBC status {:?}, {:?}, obj = {}",
solution.raw().status(),
solution.raw().secondary_status(),
solution.raw().obj_value(),
);

if solution.raw().status() != coin_cbc::raw::Status::Finished {
let initial_result =
super::faster_greedy_dag::FasterGreedyDagExtractor.extract(egraph, roots);
log::info!("Unfinished CBC solution");
return initial_result;
}
block_cycles(&mut model, &vars, &egraph);

let mut result = ExtractionResult::default();

for (id, var) in &vars {
let active = solution.col(var.active) > 0.0;
if active {
let node_idx = var
.nodes
.iter()
.position(|&n| solution.col(n) > 0.0)
.unwrap();
let node_id = egraph[id].nodes[node_idx].clone();
result.choose(id.clone(), node_id);
}
}
let solution = model.solve();
log::info!(
"CBC status {:?}, {:?}, obj = {}",
solution.raw().status(),
solution.raw().secondary_status(),
solution.raw().obj_value(),
);

return result;
if solution.raw().status() != coin_cbc::raw::Status::Finished {
assert(timeout != std::u32::MAX);

let initial_result =
super::faster_greedy_dag::FasterGreedyDagExtractor.extract(egraph, roots);
log::info!("Unfinished CBC solution");
return initial_result;
}

let mut result = ExtractionResult::default();

for (id, var) in &vars {
let active = solution.col(var.active) > 0.0;
if active {
let node_idx = var
.nodes
.iter()
.position(|&n| solution.col(n) > 0.0)
.unwrap();
let node_id = egraph[id].nodes[node_idx].clone();
result.choose(id.clone(), node_id);
}
}

return result;
}

fn block_cycles(model: &mut Model, vars: &IndexMap<ClassId, ClassVars>, egraph: &EGraph) {
Expand Down

0 comments on commit 8c1f9d7

Please sign in to comment.