-
Notifications
You must be signed in to change notification settings - Fork 1
/
paper_trainings.py
69 lines (62 loc) · 2.44 KB
/
paper_trainings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import pandas as pd
from mdeq_lib.evaluate.cls_valid import evaluate_classifier
from mdeq_lib.training.cls_train import train_classifier
def parse_n_refine(n_refine):
    """Parse one comma-separated entry of the ``--refines`` option.

    Args:
        n_refine: The number of backward refinement steps, as text (or an int),
            or the sentinel ``None`` (case-insensitive, or the Python ``None``
            object) meaning "use the default number of steps".

    Returns:
        The parsed ``int``, or ``None`` for the sentinel.

    Raises:
        ValueError: If ``n_refine`` is neither an integer nor the ``None``
            sentinel. Unparsable values are rejected rather than treated as
            the default, so a typo cannot silently launch trainings with the
            wrong configuration.
    """
    if n_refine is None or str(n_refine).strip().lower() == 'none':
        return None
    try:
        # int() itself tolerates surrounding whitespace in strings.
        return int(n_refine)
    except ValueError:
        raise ValueError(
            f"Invalid refine value {n_refine!r}: expected an integer or 'None'."
        ) from None
def _build_parser():
    """Return the command-line parser for this experiment script."""
    parser = argparse.ArgumentParser(
        description='Train CIFAR MDEQ models with different techniques.')
    # type=int lets argparse reject non-integers with a usage message instead
    # of the script failing later with a traceback.
    parser.add_argument('--n_gpus', '-g', default=4, type=int,
                        help='The number of GPUs to use.')
    # The run grid below only distinguishes cifar from "anything else"
    # (configured like imagenet), so restrict input to the documented values.
    parser.add_argument('--dataset', '-d', default='cifar',
                        choices=('cifar', 'imagenet'),
                        help='The dataset to use, either cifar or imagenet. '
                             'Defaults to cifar.')
    parser.add_argument('--n_runs', '-n', default=5, type=int,
                        help='Number of seeds to use for the figure. Defaults to 5.')
    parser.add_argument('--refines', '-r', default='0,1,2,5,7,10,None',
                        help='Number of steps to consider for backward iterations, comma-separated. '
                             'Use None to indicate the default number of steps. Defaults to 0,1,2,5,7,10,None')
    return parser


def _build_parameters(dataset, n_gpus, n_runs, n_refines):
    """Return one ``train_classifier`` keyword-argument dict per run.

    CIFAR uses the 'LARGE' model for 220 epochs; any other dataset uses the
    'SMALL' model for 100 epochs. For every seed in ``range(n_runs)`` and
    every value in ``n_refines`` (in that nesting order), it emits:

    * unless ``n_refine == 0``, the plain configuration (no shine/fpn flag);
    * unless ``n_refine is None``, a ``shine=True, refine=True`` and an
      ``fpn=True, refine=True`` configuration.
    """
    base_params = dict(
        model_size='LARGE' if dataset == 'cifar' else 'SMALL',
        dataset=dataset,
        n_gpus=n_gpus,
        n_epochs=220 if dataset == 'cifar' else 100,
    )
    parameters = []
    for seed in range(n_runs):
        for n_refine in n_refines:
            # Fresh dict per run: no entry in `parameters` aliases another.
            run_params = dict(base_params, seed=seed, n_refine=n_refine)
            if n_refine != 0:
                parameters.append(run_params)
            if n_refine is not None:
                parameters.append(dict(shine=True, refine=True, **run_params))
                parameters.append(dict(fpn=True, refine=True, **run_params))
    return parameters


def _run_experiments(parameters, results_path):
    """Train then evaluate each configuration and save results to CSV.

    The CSV at ``results_path`` is rewritten after every finished run, so
    results already obtained survive a crash in a later training.

    Returns:
        The results ``pandas.DataFrame`` (one row per configuration, with a
        'top1' column followed by the run's parameters).
    """
    res_data = []
    for params in parameters:
        train_classifier(**params)
        # Evaluation is called without the epoch count (as in the original
        # script); presumably evaluate_classifier does not accept it — confirm.
        eval_params = {key: value for key, value in params.items()
                       if key != 'n_epochs'}
        _metrics_names, eval_res = evaluate_classifier(**eval_params)
        # NOTE(review): stored as 'top1'; confirm evaluate_classifier's second
        # return value is the top-1 accuracy rather than a list of metrics.
        res_data.append({'top1': eval_res, **params})
        pd.DataFrame(res_data).to_csv(results_path)
    # Final write; also produces the (empty) file when nothing was run.
    df_res = pd.DataFrame(res_data)
    df_res.to_csv(results_path)
    return df_res


def main():
    """Parse the CLI, build the run grid, then train, evaluate and save."""
    parser = _build_parser()
    args = parser.parse_args()
    try:
        n_refines = [parse_n_refine(n_refine.strip())
                     for n_refine in args.refines.split(',')]
    except ValueError as err:
        parser.error(str(err))  # prints usage and exits
    parameters = _build_parameters(args.dataset, args.n_gpus, args.n_runs, n_refines)
    _run_experiments(parameters, f'{args.dataset}_mdeq_results.csv')


if __name__ == '__main__':
    main()