-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_configs.py
69 lines (56 loc) · 2.23 KB
/
create_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
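"""
Script for generating and saving configurations for model training.

This script builds configurations from the Cartesian product of several
hyperparameters (training-set size, model, loss function) and repeats each
combination for several runs. It then shuffles the list and stably sorts it
on the 'run' key. As a result, configs are grouped by run index, with a
random order inside each group. Finally, it writes the list to a JSON file.
"""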
import itertools
import json
import random
def _make_config(run, data_type, train_size, model, loss_function, encoding):
    """Return one training configuration dict.

    Only the swept values are parameters; optimizer, epochs, patience, test
    size and positional encoding are the fixed values every config uses.
    Key order matches the original script so the written JSON is unchanged.
    """
    return {
        'run': run,
        'data': {'type': data_type, 'train_size': train_size, 'test_size': 100},
        'model': model,
        'loss_function': loss_function,
        'optimizer': 'adam',
        'encoding': encoding,
        # NOTE(review): an earlier draft declared encodings=['domain'] and
        # pos_encodings=[1, 2] without ever using them; every config is
        # written with positional_encoding 0. Confirm whether a sweep over
        # positional encodings was intended.
        'positional_encoding': 0,
        'epochs': 300,
        'early_stop_patience': 60,
    }


def _build_configs(n_runs, train_sizes, include_vector_baselines):
    """Return the unshuffled config list, in the original generation order.

    Each (train_size, model, loss_function) combination from the Cartesian
    product is expanded into ``n_runs`` configs with run indices
    ``0 .. n_runs - 1``.
    """
    configs = []
    if include_vector_baselines:
        # 'vector' baselines. The original script built these but never
        # appended them, so they are off by default.
        vector_product = itertools.product(
            train_sizes,
            ['Baseline', 'BaselineDropout'],
            ['mse', 'mae', 'ssim', 'weighted_mse', 'weighted_mae',
             'total_mse', 'total_mae', 'total_ssim'],
        )
        for train_size, model, loss_function in vector_product:
            for run in range(n_runs):
                configs.append(_make_config(
                    run, 'vector', train_size, model, loss_function, 'none'))

    image_product = itertools.product(
        train_sizes,
        ['UNet', 'CFPNetM'],
        ['mae', 'ssim', 'weighted_mse', 'weighted_mae',
         'total_mse', 'total_mae', 'total_ssim'],
    )
    for train_size, model, loss_function in image_product:
        for run in range(n_runs):
            configs.append(_make_config(
                run, 'images', train_size, model, loss_function, 'domain'))
    return configs


def main(output_path='configs_loss.json', seed=None, include_vector_baselines=False):
    """Build the experiment configurations and write them to ``output_path`` as JSON.

    The list is shuffled and then stable-sorted on ``'run'``. The output is
    therefore grouped by run index, with a random order inside each group.

    Args:
        output_path: Destination JSON file. Defaults to ``'configs_loss.json'``
            in the current working directory, the path the script always used.
        seed: Optional seed for the shuffle. ``None`` (the default) shuffles
            with the module-level ``random`` state, as before. An int makes
            the output reproducible.
        include_vector_baselines: If true, also emit the ``'vector'`` configs
            for the ``Baseline``/``BaselineDropout`` models. These were
            previously disabled by a commented-out append, so the default
            ``False`` keeps the old output.
    """
    n_runs = 5
    # Earlier sweeps used train sizes [100, 250, 500, 750, 900] and [250].
    train_sizes = [100, 250, 500, 900]

    configs = _build_configs(n_runs, train_sizes, include_vector_baselines)

    if seed is None:
        random.shuffle(configs)
    else:
        # Private generator: seeding does not disturb the global random state.
        random.Random(seed).shuffle(configs)
    # list.sort is stable, so the shuffled order survives within each run group.
    configs.sort(key=lambda config: config['run'])

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(configs, f)
# Generate the config file only when run as a script, not when imported.
if __name__ == "__main__":
    main()