-
Notifications
You must be signed in to change notification settings - Fork 48
/
config.py
84 lines (61 loc) · 2.22 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Module metadata: authorship, version and licensing information.
__author__ = 'Jiri Fajtl'
__email__ = 'ok1zjf@gmail.com'
__version__ = '3.6'
__status__ = "Research"
__date__ = "1/12/2018"
__license__ = "MIT License"
from torch.autograd import Variable
class HParameters:
    """Hyper-parameter container with defaults and name/value overrides.

    Every setting is a plain instance attribute assigned in ``__init__``.
    Settings can be overridden from a ``{name: value}`` mapping via
    :meth:`load_from_args` and listed with ``str()``.
    """

    def __init__(self):
        """Populate every hyper-parameter with its default value."""
        self.verbose = False
        self.use_cuda = True
        self.cuda_device = 0
        self.max_summary_length = 0.15
        # Attribute name kept as-is because callers reference ``l2_req``.
        self.l2_req = 0.00001
        # NOTE(review): ``lr_epochs`` and ``lr`` look like parallel lists
        # (a learning-rate schedule) — confirm with the training loop.
        self.lr_epochs = [0]
        self.lr = [0.00005]
        self.epochs_max = 300
        self.train_batch_size = 1
        self.output_dir = 'ex-10'
        self.root = ''
        self.datasets = ['datasets/eccv16_dataset_summe_google_pool5.h5',
                         'datasets/eccv16_dataset_tvsum_google_pool5.h5',
                         'datasets/eccv16_dataset_ovp_google_pool5.h5',
                         'datasets/eccv16_dataset_youtube_google_pool5.h5']
        self.splits = ['splits/tvsum_splits.json',
                       'splits/summe_splits.json']
        self.splits += ['splits/tvsum_aug_splits.json',
                        'splits/summe_aug_splits.json']

    def get_dataset_by_name(self, dataset_name):
        """Return the first dataset path that contains ``dataset_name``.

        The match is a plain substring test against each entry of
        ``self.datasets``, in list order.

        Returns:
            A one-element list ``[path]`` for the first match, or ``None``
            when no entry contains ``dataset_name``.
        """
        for path in self.datasets:
            if dataset_name in path:
                return [path]
        return None

    def load_from_args(self, args):
        """Override attributes from a ``{name: value}`` mapping.

        Entries whose value is ``None`` are skipped, so unset options keep
        their defaults. Keys that are not existing attributes are added as
        new attributes.

        When the attribute currently holds a list and the override is a
        string, the string is split into items on commas and/or whitespace:
        ``"a,b"``, ``"a, b"`` and ``"a b"`` all give ``['a', 'b']``. The
        items stay strings. Any other override, such as an already-built
        list, is assigned unchanged.

        Args:
            args: Mapping of attribute name to override value.
        """
        for key in args:
            val = args[key]
            if val is None:
                continue
            if isinstance(val, str) and isinstance(getattr(self, key, None), list):
                # A plain ``split()`` ignored commas, so 'a,b' became ['a,b'].
                # Treating commas as separators also keeps the old
                # whitespace-separated form working.
                val = val.replace(',', ' ').split()
            setattr(self, key, val)

    def __str__(self):
        """Return one ``[index] name: value`` line per public, non-callable attribute.

        Attributes appear in ``dir()`` order, which is alphabetical. A value
        that is an instance of ``Variable`` is converted to plain Python:
        the first element of a non-empty, non-scalar tensor, or the scalar
        itself for a 0-d tensor.
        """
        names = [name for name in dir(self)
                 if not name.startswith('_') and not callable(getattr(self, name))]
        lines = []
        for i, name in enumerate(names):
            val = getattr(self, name)
            if isinstance(val, Variable):
                val = val.data.cpu().numpy().tolist()
                # tolist() returns a bare scalar for 0-d tensors, so only
                # index into a real, non-empty list.
                if isinstance(val, list) and val:
                    val = val[0]
            lines.append('[' + str(i) + '] ' + name + ': ' + str(val) + '\n')
        return ''.join(lines)
if __name__ == "__main__":
    # Manual smoke demo: show the defaults, apply a few overrides, show again.
    params = HParameters()
    print(params)
    overrides = {
        'root': 'root_dir',
        'datasets': 'set1,set2,set3',
        'splits': 'split1, split2',
        'new_param_float': 1.23456,
    }
    params.load_from_args(overrides)
    print(params)