forked from elijahcole/sinr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_and_evaluate_models.py
68 lines (56 loc) · 1.96 KB
/
train_and_evaluate_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import numpy as np
import torch
import train
import eval
# Training overrides. Any key not listed here keeps its default value from
# the default training parameter set in setup.py.
train_params = {
    # Name of the directory where results for this run are saved.
    'experiment_name': 'unconditional',
    # Which set of species to train on.
    # Valid values: 'all', 'snt_birds'.
    'species_set': 'all',
    # Maximum number of examples per class to use for training.
    # Valid values: positive integers, or -1 for no cap.
    'hard_cap_num_per_class': 1000,
    # Number of random additional species to add.
    # Valid values: non-negative integers; should be 0 when
    # train_params['species_set'] == 'all'.
    'num_aux_species': 0,
    # Type of inputs to use for training.
    # Valid values: 'sin_cos', 'env', 'sin_cos_env'.
    'input_enc': 'sin_cos',
    # Loss used for training. Previously documented values are 'an_full',
    # 'an_slds', 'an_ssdl', 'an_full_me', 'an_slds_me', 'an_ssdl_me'; this
    # run uses an unconditional variant that is not in that list.
    # NOTE(review): 'uncondititonal' looks like a misspelling of
    # 'unconditional'. The string must match the key the training code
    # dispatches on, so confirm against the loss lookup before changing it.
    'loss': 'an_full_uncondititonal',
    # Model to train. NOTE(review): presumably a model class defined in this
    # fork's model code; confirm the name is registered there.
    'model': 'ResidualFCNetUnconditional',
}

# Train:
train.launch_training_run(train_params)
# Evaluate the trained experiment under each evaluation protocol and save
# each protocol's results as results_<eval_type>.npy inside the experiment
# directory.
# NOTE: `eval` here is the project's eval module imported at the top of the
# file (it shadows the builtin of the same name).
for eval_type in ('snt', 'iucn', 'geo_prior', 'geo_feature'):
    eval_params = {
        'exp_base': './experiments',
        'experiment_name': train_params['experiment_name'],
        'eval_type': eval_type,
    }
    if eval_type == 'iucn':
        # IUCN evaluation is forced onto the CPU for memory reasons.
        eval_params['device'] = torch.device('cpu')
    cur_results = eval.launch_eval_run(eval_params)
    # Read the path components after the run, matching the original order.
    results_path = os.path.join(
        eval_params['exp_base'],
        train_params['experiment_name'],
        f'results_{eval_type}.npy',
    )
    np.save(results_path, cur_results)

# Note that train_params and eval_params do not contain all of the parameters
# of interest. Instead, there are default parameter sets for training and
# evaluation (which can be found in setup.py). In this script we create
# dictionaries of key-value pairs that are used to override the defaults as
# needed.