-
Notifications
You must be signed in to change notification settings - Fork 14
/
multicut_solver_options.h
128 lines (122 loc) · 6.54 KB
/
multicut_solver_options.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once
#include <iostream>
#include <stdexcept>
#include <string>

#include <CLI/CLI.hpp>
/// @brief Configuration for the multicut solver.
///
/// An instance can be obtained from the in-class defaults (default ctor),
/// from a named preset ("PD", "P", "PD+", "D"), from explicit values, or by
/// parsing command-line arguments with from_cl().
struct multicut_solver_options {
    std::string input_file;          ///< Path to the multicut instance (.txt); filled by from_cl().
    std::string output_sol_file="";  ///< Path to save the node labeling (.txt); empty if not requested.
    // Initial dual (lower-bound) updates: max conflicted-cycle length and
    // dual update iterations per cycle.
    int max_cycle_length_lb = 5;
    int num_dual_itr_lb = 10;
    // Dual updates used for reparametrization during primal iterations.
    int max_cycle_length_primal = 3;
    int num_dual_itr_primal = 5;
    int num_outer_itr_dual = 1;      ///< Outer dual iterations; larger values detect conflicted cycles again.
    float mean_multiplier_mm = 0.05; ///< Match edges with cost > mean(positive edges) * this factor.
    /// Ratio (# contract edges / # nodes) at which contraction switches from
    /// maximum-matching based to MST based; values > 1 always use MST.
    float matching_thresh_crossover_ratio = 0.05;
    float tri_memory_factor = 2.0;   ///< Average triangles per repulsive edge; used for memory allocation.
    bool only_compute_lb = false;    ///< Only compute the lower bound.
    int max_time_sec = -1;           // NOTE(review): -1 presumably means "no time limit" — confirm in the solver.
    bool dump_timeline = false;      ///< Return the output of each contraction step (debugging only, slow).
    bool verbose = true;
    bool sanitize_graph = false;     ///< Input may contain nodes without edges; those get label -1.

    /// @brief Construct with the in-class default values.
    multicut_solver_options() { }

    /// @brief Construct options from a named preset, printing a short
    ///        description of the chosen preset to std::cout.
    /// @param solver_type One of:
    ///   - "PD":  the in-class defaults;
    ///   - "P":   purely primal (all cycle lengths and dual iterations set to 0);
    ///   - "PD+": longer cycles / more dual iterations during primal iterations;
    ///   - "D":   dual only; sets only_compute_lb and 5 outer dual iterations.
    /// @throws std::runtime_error if solver_type is none of the above.
    multicut_solver_options(const std::string& solver_type) {
        if(solver_type == "PD")
        {
            std::cout<<"Running solver type PD which offers best compute time versus quality tradeoff."<<std::endl;
        }
        else if(solver_type == "P")
        {
            std::cout<<"Running purely primal solver (better runtime, worse quality)"<<std::endl;
            max_cycle_length_lb = 0;
            num_dual_itr_lb = 0;
            max_cycle_length_primal = 0;
            num_dual_itr_primal = 0;
        }
        else if(solver_type == "PD+")
        {
            std::cout<<"Running PD+ solver (worse runtime, better quality)"<<std::endl;
            max_cycle_length_lb = 5;
            num_dual_itr_lb = 10;
            max_cycle_length_primal = 5;
            num_dual_itr_primal = 10;
        }
        else if(solver_type == "D")
        {
            std::cout<<"Running dual solver to compute only the lower bound."<<std::endl;
            max_cycle_length_lb = 5;
            num_dual_itr_lb = 10;
            num_outer_itr_dual = 5;
            only_compute_lb = true;
        }
        else
        {
            // Reject unknown presets instead of silently running with defaults.
            throw std::runtime_error("invalid solver_type specified.");
        }
    }

    /// @brief Construct from explicit values for every tunable option.
    ///        input_file, output_sol_file and verbose keep their defaults.
    multicut_solver_options(
            const int _max_cycle_length_lb,
            const int _num_dual_itr_lb,
            const int _max_cycle_length_primal,
            const int _num_dual_itr_primal,
            const int _num_outer_itr_dual,
            const float _mean_multiplier_mm,
            const float _matching_thresh_crossover_ratio,
            const float _tri_memory_factor,
            const bool _only_compute_lb,
            const int _max_time_sec,
            const bool _dump_timeline = false,
            const bool _sanitize_graph = false) :
        max_cycle_length_lb(_max_cycle_length_lb),
        num_dual_itr_lb(_num_dual_itr_lb),
        max_cycle_length_primal(_max_cycle_length_primal),
        num_dual_itr_primal(_num_dual_itr_primal),
        num_outer_itr_dual(_num_outer_itr_dual),
        mean_multiplier_mm(_mean_multiplier_mm),
        matching_thresh_crossover_ratio(_matching_thresh_crossover_ratio),
        tri_memory_factor(_tri_memory_factor),
        only_compute_lb(_only_compute_lb),
        max_time_sec(_max_time_sec),
        dump_timeline(_dump_timeline),
        sanitize_graph(_sanitize_graph)
    {}

    /// @brief Populate the options from command-line arguments via CLI11.
    /// @param argc Argument count as passed to main().
    /// @param argv Argument vector as passed to main().
    /// @return -1 if parsing succeeded and the caller should continue;
    ///         otherwise the value returned by CLI::App::exit() for the caught
    ///         CLI::ParseError, which the caller is expected to return.
    int from_cl(int argc, char** argv) {
        CLI::App app{"Solver for multicut problem. "};
        app.add_option("-f,--file,file_pos", input_file, "Path to multicut instance (.txt)")->required()->check(CLI::ExistingPath);
        app.add_option("-o,--out_sol_file", output_sol_file, "Path to save node labeling (.txt)");
        // NOTE(review): option names without a leading '-' are registered by
        // CLI11 as positional arguments, so the options below are given by
        // position — confirm this is intended.
        app.add_option("max_cycle_dual", max_cycle_length_lb, "Maximum length of conflicted cycles to consider for initial dual updates. (Default: 5).")->check(CLI::Range(0, 5));
        app.add_option("dual_itr", num_dual_itr_lb, "Number of dual update iterations per cycle. (Default: 10).")->check(CLI::NonNegativeNumber);
        app.add_option("max_cycle_primal", max_cycle_length_primal, "Maximum length of conflicted cycles to consider during primal iterations for reparameterization. (Default: 3).")->check(CLI::Range(0, 5));
        app.add_option("dual_itr_primal", num_dual_itr_primal, "Number of dual update iterations per cycle during primal reparametrization. (Default: 5).")->check(CLI::NonNegativeNumber);
        app.add_option("dual_itr_outer", num_outer_itr_dual, "Number of outer dual iterations for initial dual updates. Larger number detects conflicted cycles again. (Default: 1).")->check(CLI::NonNegativeNumber);
        app.add_option("mean_multiplier_mm", mean_multiplier_mm, "Match the edges which have cost more than mean(pos edges) * mean_multiplier_mm. (Default: 0.05).")->check(CLI::NonNegativeNumber);
        // The documented default must match the member initializer (0.05).
        app.add_option("matching_thresh_crossover_ratio", matching_thresh_crossover_ratio, "Ratio of (# contract edges / # nodes ) at which to change from maximum matching based contraction to MST based. "
                "(Default: 0.05). Greater than 1 will always use MST.")->check(CLI::NonNegativeNumber);
        app.add_option("tri_memory_factor", tri_memory_factor,
                "Average number of triangles per repulsive edge. (Used for memory allocation. Use lesser value in-case of out of memory errors during dual solve). (Default: 2.0).")->check(CLI::PositiveNumber);
        app.add_flag("--only_lb", only_compute_lb, "Only compute the lower bound. (Default: false).");
        app.add_flag("--dump_timeline", dump_timeline, "Return the output of each contraction step. Only use for debugging/visualization purposes. (slow). (Default: false).");
        app.add_flag("--sanitize_graph", sanitize_graph, "If the input graph contains nodes without any edges and thus needs sanitizing. Cluster labels in this case will be -1 for these nodes. (Default: false).");
        try {
            app.parse(argc, argv);
            return -1; // Sentinel: parsing succeeded, caller continues.
        } catch (const CLI::ParseError &e) {
            return app.exit(e);
        }
    }

    /// @brief Human-readable single-line summary of the solver options,
    ///        terminated by a newline. input/output paths and dump_timeline
    ///        are not included.
    std::string get_string() const
    {
        return std::string("<multicut_solver_options>:") +
            "max_cycle_length_lb: " + std::to_string(max_cycle_length_lb) +
            ", num_dual_itr_lb: " + std::to_string(num_dual_itr_lb) +
            ", num_dual_itr_primal: " + std::to_string(num_dual_itr_primal) +
            ", max_cycle_length_primal: " + std::to_string(max_cycle_length_primal) +
            ", num_outer_itr_dual: " + std::to_string(num_outer_itr_dual) +
            ", mean_multiplier_mm: " + std::to_string(mean_multiplier_mm) +
            ", matching_thresh_crossover_ratio: " + std::to_string(matching_thresh_crossover_ratio) +
            ", tri_memory_factor: " + std::to_string(tri_memory_factor) +
            ", only_compute_lb: " + std::to_string(only_compute_lb) +
            ", max_time_sec: " + std::to_string(max_time_sec) +
            ", sanitize_graph: " + std::to_string(sanitize_graph) +
            ", verbose: " + std::to_string(verbose) + "\n";
    }

    /// @brief Write get_string() to std::cout.
    void print() const
    {
        std::cout<<this->get_string();
    }
};