// options.cpp
#include "options.hpp"
#include "utilities.hpp"
#include "graph.hpp"
#include "network.hpp"
namespace jup {
Schaf_options global_options;
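// Command-line tokens are pushed onto a stack in reverse order, so popping
// yields them left to right. `current` holds the token being processed;
// `current_option` is set while an option's arguments are parsed, solely to
// improve error messages.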
struct Parse_state {
Array<jup_str> stack;
jup_str current;
jup_str current_option;
};
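// Advances to the next token. Returns false if the input is exhausted.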
static bool pop(Parse_state* state) {
assert(state);
if (not state->stack) {
return false;
} else {
state->current = state->stack.pop_back();
return true;
}
}
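// Reports a parse error (naming the offending option, if any) and exits with
// code 3.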
[[noreturn]] static void parse_die(Parse_state* state, jup_str message) {
assert(state);
jerr << "Error: " << message << '\n';
if (state->current_option) {
jerr << "Error: in option " << state->current_option << '\n';
}
jerr << "Error: while parsing command line\n";
jerr << "\nUse the --help option for usage information.\n";
std::exit(3);
}
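// Pops the next token as an option argument; dies if none is left.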
static void pop_option_arg(Parse_state* state) {
assert(state);
if (not pop(state)) {
parse_die(state, "Unexpected end of input, expected option argument");
}
}
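// Parses the current token as an int and checks it against [lower, upper].
// Every failure path calls parse_die, which is [[noreturn]], so the function
// either returns a valid value or exits.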
static int get_int(
Parse_state* state,
int lower = std::numeric_limits<int>::min(),
int upper = std::numeric_limits<int>::max()
) {
assert(state);
int value;
auto code = jup_stox(state->current, &value);
if (code) {
parse_die(state, jup_err_messages[code]);
} else if (value < lower) {
parse_die(state, jup_printf("The value is too small, must be at least %d", lower));
} else if (value > upper) {
parse_die(state, jup_printf("The value is too big, must be at most %d", upper));
} else {
return value;
}
}
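// Parses the current token as a float and checks it against [lower, upper],
// analogous to get_int.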
static float get_float(
Parse_state* state,
float lower = -std::numeric_limits<float>::infinity(),
float upper = std::numeric_limits<float>::infinity()
) {
assert(state);
float value;
auto code = jup_stox(state->current, &value);
if (code) {
parse_die(state, jup_err_messages[code]);
} else if (value < lower) {
parse_die(state, jup_printf("The value is too small, must be at least %f", lower));
} else if (value > upper) {
parse_die(state, jup_printf("The value is too big, must be at most %f", upper));
} else {
return value;
}
}
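// Prints the full usage text (modes and options) to jerr, line-wrapped.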
static void print_usage() {
Buffer str;
str.append(
"Usage:\n"
" schaf [options] [--] mode [args]\n"
"\n"
"Modes:\n"
" write_graph <input> <output>\n"
" Executes the job specified in the jobfile <input>, and writes the resulting graphs "
"into the file <output>. It is recommended that <output> has the extension "
"'.schaf.lz4'.\n"
"\n"
" print_stats <input> [output]\n"
" Reads the graphs from the file <input> and prints information about them to the "
"console. If <output> is specified, the information will additionally be written to "
"that file, in a machine-readable format.\n"
"\n"
" prepare_data <input> <output>\n"
" Generates training data for the neural network by reading the graphs in <input> and "
"writes it into <output>.\n"
"\n"
" train <input>\n"
" Read the training data contained in the file <input> and train the network.\n"
"\n"
" print_data_info <input>\n"
" Reads the training data from the file <input> and prints information about it to the "
"console.\n"
"\n"
" dump_graph <input> <output> [index]\n"
" Takes a graph from <input> with index <index> and writes a gdf file (as used in "
"GUESS) describing it. That file can then be displayed by graph-visualisation tools. If "
"<index> is omitted, the first graph is taken.\n"
"\n"
" dump_graph_random <output> [seed]\n"
" Randomly generates a graph and writes a gdf file (see dump_graph) describing it. If "
"<seed> is omitted, 0 is used.\n"
"\n"
" grid_search <input>\n"
" Randomly generates sets of hyperparameters and optimises them for some time. The "
"results are printed.\n"
"\n"
" cross_validate <input>\n"
" Evaluates the network on the specified training data. Specify the parameters to use "
"via the --param-in option; note that the hyperparameters have to match the saved "
"network!\n"
"\n"
" classify <input>\n"
" Read the graphs from <input> and classify them using the network. Specify the "
"parameters to use via the --param-in option; note that the hyperparameters have to "
"match the saved network!\n"
"\n"
"Options:\n"
" --edges-min,-e <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_EDGES_MIN) "]\n"
" --edges-max,-E <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_EDGES_MAX) "]\n"
" Limits the graphs that are written to graphs with a number of edges inside the "
"specified range.\n"
"\n"
" --batch-count,-N <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_BATCH_COUNT) "]\n"
" Number of batches per training data.\n"
"\n"
" --batch-size,-n <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_BATCH_SIZE) "]\n"
" Number of instances per batch.\n"
"\n"
" --recf-nodes <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_RECF_NODES) "]\n"
" Number of nodes per receptive field.\n"
"\n"
" --recf-count <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_RECF_COUNT) "]\n"
" Number of receptive fields.\n"
"\n"
" --a1-size <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_A1_SIZE) "]\n"
" --a2-size <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_A2_SIZE) "]\n"
" --b1-size <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_B1_SIZE) "]\n"
" --b2-size <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_B2_SIZE) "]\n"
" Sizes of the different layers of the neural network.\n"
"\n"
" --gen-instances <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_GEN_INSTANCES) "]\n"
" Number of instances generated per graph. Note that these instances may use the same "
"nodes. Only relevant during mode prepare_data.\n"
"\n"
" --learning-rate,-l <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_LEARNING_RATE) "]\n"
" The initial learning rate of the network. Note that when loading a parameter file, "
"the saved learning rate will be used instead.\n"
"\n"
" --learning-rate-decay,-L <val> [default: 0]\n"
" The amount of epochs after which the learning rate is halved. Set to 0 to disable "
"learning rate decay.\n"
"\n"
" --dropout,-d <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_DROPOUT) "]\n"
" The dropout to use for the network, that is the fraction of nodes that is retained "
"while training. Set to 1.0 to disable dropout.\n"
"\n"
" --l2reg,-2 <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_L2REG) "]\n"
" The regularisation strength as applied to the l2 regularisation. Set to 0.f to disable.\n"
"\n"
" --seed,-s <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_SEED) "]\n"
" Seed to initialise tensorflow randomness. If set to 0, random randomness is used.\n"
"\n"
" --test-frac <val> [default: " JUP_STRINGIFY(JUP_DEFAULT_TEST_FRAC) "]\n"
" The fraction of the data set that is used as test data.\n"
"\n"
" --param-in,-i <path> [default: none]\n"
" The parameter file to load. It is used to initialize the networks parameters and the "
"learning rate.\n"
"\n"
" --iter-max <value> [default: none]\n"
" The maximum number of training iterations for the network.\n"
"\n"
" --iter-save <value> [default: " JUP_STRINGIFY(JUP_DEFAULT_ITER_SAVE) "]\n"
" The number of iterations after which the parameters will be saved. Set to 0 to "
"disable saving. The files will be saved in the directory specified by --logdir. If "
"that options is not set, saving of parameters will also be disabled.\n"
"\n"
" --iter-event <value> [default: " JUP_STRINGIFY(JUP_DEFAULT_ITER_EVENT) "]\n"
" The number of iterations after which a summary for tensorboard will be written. Set "
"to 0 to disable. The files will be saved in the directory specified by --logdir. If "
"that options is not set, summaries will also be disabled.\n"
"\n"
" --logdir <path> [default: " JUP_DEFAULT_LOGDIR "]\n"
" The location to write the summary logfiles (for tensorboard) and the parameter values "
"to. The directory will be created, if necessary. If this is the empty string, both "
"logging and saving of parameters are disabled.\n"
"\n"
" --grid-max-time,-T <value> [default: " JUP_STRINGIFY(JUP_DEFAULT_GRID_MAX_TIME) "]\n"
" The amount of time a chosen set of hyperparameters (during grid search) is allowed to "
"optimise, before being terminated.\n"
"\n"
" --grid-params <batch-size> <rate> <decay> <a1-size> <a2> <b1> <b2> <dropout> <l2reg>\n"
" Set all the hyperparameters at once. Useful for just copy-pasting a grid-search "
"result. You probably want to set the batch count before this.\n"
"\n"
" --samples,-S <value> [default: " JUP_STRINGIFY(JUP_DEFAULT_SAMPLES) "]\n"
" Number of times a neighbourhood will be generated for each graph in mode classify.\n"
"\n"
" --profile <path> [default: none]\n"
" Enables profiling. The results will be written to the specified location. Note that "
"profiling is "
#ifdef USE_PROFILER
"ENABLED"
#else
"DISABLED"
#endif
" in this executable. (Build with USE_PROFILER=1 to enable.)\n"
"\n"
" --help,-h\n"
" Prints this message.\n"
"\n"
);
print_wrapped(jerr, str);
}
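// Consumes one token from the input. If it is an option, it (and its
// arguments) are applied to *options and true is returned. If it is a mode
// name (or follows a "--" separator), it is left in state->current and false
// is returned, ending option parsing.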
static bool parse_option(Schaf_options* options, Parse_state* state) {
if (not pop(state)) {
parse_die(state, "Unexpected end of input, expected an option or mode.");
}
if (not state->current.size() or state->current.front() != '-') {
return false;
}
state->current_option = state->current;
if (
state->current == "--help" or state->current == "-h"
or state->current == "-?" or state->current == "/?"
) {
print_usage();
std::exit(4);
} else if (state->current == "--edges-min" or state->current == "-e") {
pop_option_arg(state);
options->graph_min_edges = get_int(state, 1);
} else if (state->current == "--edges-max" or state->current == "-E") {
pop_option_arg(state);
options->graph_max_edges = get_int(state, 1);
} else if (state->current == "--batch-count" or state->current == "-N") {
pop_option_arg(state);
options->hyp.batch_count = get_int(state, 1);
} else if (state->current == "--batch-size" or state->current == "-n") {
pop_option_arg(state);
options->hyp.batch_size = get_int(state, 1);
} else if (state->current == "--recf-nodes") {
pop_option_arg(state);
options->hyp.recf_nodes = get_int(state, 1);
} else if (state->current == "--recf-count") {
pop_option_arg(state);
options->hyp.recf_count = get_int(state, 1);
} else if (state->current == "--gen-instances") {
pop_option_arg(state);
options->hyp.gen_instances = get_int(state, 1);
} else if (state->current == "--a1-size") {
pop_option_arg(state);
options->hyp.a1_size = get_int(state, 1);
} else if (state->current == "--a2-size") {
pop_option_arg(state);
options->hyp.a2_size = get_int(state, 1);
} else if (state->current == "--b1-size") {
pop_option_arg(state);
options->hyp.b1_size = get_int(state, 1);
} else if (state->current == "--b2-size") {
pop_option_arg(state);
options->hyp.b2_size = get_int(state, 1);
} else if (state->current == "--learning-rate" or state->current == "-l") {
pop_option_arg(state);
options->hyp.learning_rate = get_float(state, 0.f);
} else if (state->current == "--learning-rate-decay" or state->current == "-L") {
pop_option_arg(state);
options->hyp.learning_rate_decay = get_int(state, 0);
} else if (state->current == "--dropout" or state->current == "-d") {
pop_option_arg(state);
options->hyp.dropout = get_float(state, 0.1f, 1.f);
} else if (state->current == "--l2reg" or state->current == "-2") {
pop_option_arg(state);
options->hyp.l2_reg = get_float(state, 0.f);
} else if (state->current == "--seed" or state->current == "-s") {
pop_option_arg(state);
options->hyp.seed = (u64)get_int(state);
} else if (state->current == "--test-frac") {
pop_option_arg(state);
options->hyp.test_frac = get_float(state, 0.f, 1.0f);
} else if (state->current == "--param-in" or state->current == "-i") {
pop_option_arg(state);
options->param_in = state->current;
} else if (state->current == "--iter-max") {
pop_option_arg(state);
options->iter_max = get_int(state, 0);
} else if (state->current == "--iter-save") {
pop_option_arg(state);
options->iter_save = get_int(state, 1);
} else if (state->current == "--iter-event") {
pop_option_arg(state);
options->iter_event = get_int(state, 0);
} else if (state->current == "--logdir") {
pop_option_arg(state);
options->logdir = state->current;
} else if (state->current == "--grid-max-time" or state->current == "-T") {
pop_option_arg(state);
options->grid_max_time = get_float(state, 0.f);
} else if (state->current == "--grid-params") {
int inst_total = options->hyp.num_instances();
pop_option_arg(state); options->hyp.batch_size = get_int(state, 1);
options->hyp.batch_count = inst_total / options->hyp.batch_size;
pop_option_arg(state); options->hyp.learning_rate = get_float(state, 0.f);
pop_option_arg(state); options->hyp.learning_rate_decay = get_int(state, 1);
pop_option_arg(state); options->hyp.a1_size = get_int(state, 1);
pop_option_arg(state); options->hyp.a2_size = get_int(state, 1);
pop_option_arg(state); options->hyp.b1_size = get_int(state, 1);
pop_option_arg(state); options->hyp.b2_size = get_int(state, 1);
pop_option_arg(state); options->hyp.dropout = get_float(state, 0.1f, 1.f);
pop_option_arg(state); options->hyp.l2_reg = get_float(state, 0.f);
} else if (state->current == "--samples") {
pop_option_arg(state);
options->samples = get_int(state, 1);
} else if (state->current == "--profile") {
pop_option_arg(state);
options->profiler_loc = state->current;
} else if (state->current == "--") {
if (not pop(state)) {
parse_die(state, "Unexpected end of input, expected a mode.");
}
return false;
} else {
parse_die(state, "Unknwon option.");
}
state->current_option = jup_str {};
return true;
}
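// Parses all options, then dispatches on the mode name. A hypothetical
// invocation (paths and values are illustrative only):
//   schaf --edges-min 100 --edges-max 5000 write_graph jobs.txt out.schaf.lz4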
void options_execute(Schaf_options* options, Array_view<jup_str> args) {
assert(options);
Parse_state state;
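// Push the arguments in reverse, so that pop() yields them front to back.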
for (int i = 0; i < args.size(); ++i) {
state.stack.push_back(args[args.size() - i - 1]);
}
while (parse_option(options, &state)) {}
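// The first constructor argument is profiler_loc.size(), which is zero
// unless --profile was given; presumably this enables or disables profiling.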
Profiler_context profiler_context {options->profiler_loc.size(), options->profiler_loc, true};
if (state.current == "write_graph") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode write_graph.");
}
jup_str input = state.current;
if (not pop(&state)) {
parse_die(&state, "Expected the <output> argument to mode write_graph.");
}
jup_str output = state.current;
graph_exec_jobfile(input, output);
} else if (state.current == "print_stats") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode print_stats.");
}
jup_str input = state.current;
jup_str output;
if (pop(&state)) {
output = state.current;
}
graph_print_stats(input, output);
} else if (state.current == "prepare_data") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode prepare_data.");
}
jup_str input = state.current;
if (not pop(&state)) {
parse_die(&state, "Expected the <output> argument to mode prepare_data.");
}
jup_str output = state.current;
network_prepare_data(input, output, options->hyp);
} else if (state.current == "train") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode train.");
}
jup_str input = state.current;
network_train(input);
} else if (state.current == "print_data_info") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode print_data_info.");
}
jup_str input = state.current;
network_print_data_info(input);
} else if (state.current == "dump_graph") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode dump_graph.");
}
jup_str input = state.current;
if (not pop(&state)) {
parse_die(&state, "Expected the <output> argument to mode dump_graph.");
}
jup_str output = state.current;
int index = 0;
if (pop(&state)) {
index = get_int(&state, 0);
}
graph_dump(input, output, index);
} else if (state.current == "dump_graph_random") {
if (not pop(&state)) {
parse_die(&state, "Expected the <output> argument to mode dump_graph_random.");
}
jup_str output = state.current;
u64 seed = 0;
if (pop(&state)) {
seed = (u64)get_int(&state);
}
graph_dump_random(output, seed);
} else if (state.current == "grid_search") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode grid_search.");
}
jup_str input = state.current;
network_grid_search(input);
} else if (state.current == "cross_validate") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode cross_validate.");
}
jup_str input = state.current;
network_cross_validate(input);
} else if (state.current == "classify") {
if (not pop(&state)) {
parse_die(&state, "Expected the <input> argument to mode classify.");
}
jup_str input = state.current;
network_classify(input);
} else {
auto s = jup_printf(
"Unknown mode: \"%s\"",
state.current
);
parse_die(&state, s);
}
}
} /* end of namespace jup */