-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
124 lines (107 loc) · 3.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import yaml
import argparse
import importlib
import os.path as osp
class AverageMeter(object):
    """Computes and stores the average and current value.

    Tracks the most recent value, a running weighted sum, the total weight
    (``count``) and the running mean ``sum / count``.

    Args:
        name: label printed by ``__str__``.
        fmt: format spec (e.g. ``':.3f'``) applied to both the current value
            and the average when printing.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the newest value; stored as ``self.val``.
            n: weight of ``val`` (e.g. batch size). ``val * n`` is added to
                the running sum and ``n`` to the count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against ZeroDivisionError when nothing has been counted yet
        # (e.g. update(v, n=0) right after reset()); the average stays 0.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Produces e.g. "loss 0.123 (0.456)" -> "<name> <val> (<avg>)".
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def dict2namespace(config):
    """Recursively turn a mapping into an ``argparse.Namespace``.

    Nested ``dict`` values become nested namespaces, so ``cfg["a"]["b"]`` is
    reachable as ``cfg.a.b``. Any other value (including lists, even lists
    that contain dicts) is attached unchanged. An object that is already a
    Namespace is returned as-is.
    """
    if isinstance(config, argparse.Namespace):
        return config
    ns = argparse.Namespace()
    for attr_name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(ns, attr_name, converted)
    return ns
def _scan_checkpoints(ckpt_path, verbose=False):
    """Return a dict mapping epoch -> path for the ``*.pt`` files in ckpt_path.

    The epoch is read from the second ``_``-separated field of the file name
    with the extension removed. Files without such a numeric field (e.g.
    ``best.pt``) are skipped instead of raising. A missing directory yields an
    empty dict.
    """
    ep2file = {}
    if not osp.isdir(ckpt_path):
        return ep2file
    # Sort so the result is deterministic; os.listdir order is arbitrary.
    for fname in sorted(os.listdir(ckpt_path)):
        if not fname.endswith(".pt"):
            continue
        # NOTE(review): assumes names like "epoch_<N>_....pt" — confirm
        # against the trainer's checkpoint-saving code.
        parts = osp.splitext(fname)[0].split("_")
        try:
            ep = int(parts[1])
        except (IndexError, ValueError):
            if verbose:
                print("skipping", fname)
            continue
        if verbose:
            print(ep, fname)
        ep2file[ep] = osp.join(ckpt_path, fname)
    return ep2file


def load_imf(log_path, config_fpath=None, ckpt_fpath=None,
             epoch=None, verbose=False,
             return_trainer=False, return_cfg=False):
    """Build the trainer described by a run's config and restore a checkpoint.

    Args:
        log_path: run directory. By default the config is read from
            ``<log_path>/config/config.yaml`` and checkpoints are looked up
            in ``<log_path>/checkpoints``.
        config_fpath: explicit config file; overrides the default location.
        ckpt_fpath: explicit checkpoint file. When given, the checkpoint
            directory is not scanned.
        epoch: if not None, load the checkpoint for this epoch found by the
            directory scan instead of the latest one.
        verbose: print every checkpoint discovered during the scan.
        return_trainer: if True return ``(trainer, cfg)``, otherwise
            ``(trainer.net, cfg)``.
        return_cfg: unused; ``cfg`` is always returned. Kept for backward
            compatibility.

    Returns:
        ``(trainer, cfg)`` when ``return_trainer`` is True, else
        ``(trainer.net, cfg)``.

    Raises:
        KeyError: ``epoch`` was given but no scanned checkpoint matches it
            (always the case when ``ckpt_fpath`` is given).
    """
    # Load configuration
    if config_fpath is None:
        config_fpath = osp.join(log_path, "config", "config.yaml")
    with open(config_fpath) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted config files.
        cfg = dict2namespace(yaml.load(f, Loader=yaml.Loader))
    cfg.save_dir = "logs"

    # Pick the checkpoint: explicit path > newest epoch found > latest.pt.
    ep2file = {}
    last_file = osp.join(log_path, "latest.pt")
    if ckpt_fpath is not None:
        last_file = ckpt_fpath
    else:
        ep2file = _scan_checkpoints(osp.join(log_path, "checkpoints"),
                                    verbose=verbose)
        last_ep = max(ep2file, default=-1)
        if last_ep > -1:
            last_file = ep2file[last_ep]

    if epoch is not None:
        if epoch not in ep2file:
            raise KeyError("No checkpoint found for epoch %r (available: %s)"
                           % (epoch, sorted(ep2file)))
        last_file = ep2file[epoch]
    print(last_file)

    # The trainer module is named in the config; the second constructor
    # argument is passed as None here (presumably CLI args — confirm).
    trainer_lib = importlib.import_module(cfg.trainer.type)
    trainer = trainer_lib.Trainer(cfg, None)
    trainer.resume(last_file)
    if return_trainer:
        return trainer, cfg
    imf = trainer.net
    del trainer
    return imf, cfg
def parse_hparams(hparam_lst):
    """Parse ``key=value`` strings into a dict plus a filename-friendly tag.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``. Whitespace around keys and values is stripped. Parsing
    progress is printed to stdout.

    Args:
        hparam_lst: iterable of ``"key=value"`` strings.

    Returns:
        ``(out, out_str)``: ``out`` maps each key to its (string) value, later
        duplicates winning; ``out_str`` concatenates ``"key=value_"`` for every
        entry, with ``/`` in values replaced by ``-``.

    Raises:
        ValueError: an entry contains no ``=``.
    """
    print("=" * 80)
    print("Parsing:", hparam_lst)
    out = {}
    tag_parts = []
    for hparam in hparam_lst:
        key, sep, value = hparam.strip().partition("=")
        if not sep:
            raise ValueError("Expected 'key=value', got %r" % hparam)
        key = key.strip()
        value = value.strip()
        print(key, value)
        out[key] = value
        # "/" would create sub-directories if the tag is used in a path.
        tag_parts.append("%s=%s_" % (key, value.replace("/", "-")))
    out_str = "".join(tag_parts)
    print(out)
    print(out_str)
    print("=" * 80)
    return out, out_str
def _cast_hparam_value(current, raw):
    """Convert the string ``raw`` to the type of the existing value ``current``.

    - bool: ``true/yes/y/on/1`` -> True, ``false/no/n/off/0`` -> False
      (case-insensitive). Plain ``bool(raw)`` would be True for any non-empty
      string, including ``"False"``.
    - None: there is no type to copy, so ``raw`` is parsed as a Python literal
      (``"3"`` -> 3, ``"[1, 2]"`` -> [1, 2]); if that fails, the raw string is
      kept.
    - anything else: ``type(current)(raw)``.
    """
    if isinstance(current, bool):
        lowered = raw.strip().lower()
        if lowered in ("true", "yes", "y", "on", "1"):
            return True
        if lowered in ("false", "no", "n", "off", "0"):
            return False
        raise ValueError("Cannot interpret %r as a boolean" % raw)
    if current is None:
        import ast
        try:
            return ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            return raw
    return type(current)(raw)


def update_cfg_with_hparam(cfg, k, v):
    """Set the dotted attribute path ``k`` on ``cfg`` to the string ``v``.

    Every component of ``k`` (e.g. ``"trainer.opt.lr"``) must already exist
    on ``cfg``; ``v`` is converted to the type of the existing value.

    Raises:
        AssertionError: a component of ``k`` is missing.
        ValueError: ``v`` cannot be converted (e.g. a non-boolean string for
            a boolean entry, or ``int("abc")``).
    """
    k_path = k.split(".")
    cfg_curr = cfg
    for k_curr in k_path[:-1]:
        assert hasattr(cfg_curr, k_curr), "%s not in %s" % (k_curr, cfg_curr)
        cfg_curr = getattr(cfg_curr, k_curr)
    k_final = k_path[-1]
    assert hasattr(cfg_curr, k_final), \
        "Final: %s not in %s" % (k_final, cfg_curr)
    current = getattr(cfg_curr, k_final)
    setattr(cfg_curr, k_final, _cast_hparam_value(current, v))
def update_cfg_hparam_lst(cfg, hparam_lst):
    """Apply a list of ``key=value`` overrides to ``cfg`` in place.

    Each entry is parsed by ``parse_hparams`` and then written onto ``cfg``
    via ``update_cfg_with_hparam``, in parse order.

    Returns:
        ``(cfg, tag)`` where ``tag`` is the filename-friendly string built by
        ``parse_hparams``.
    """
    overrides, tag = parse_hparams(hparam_lst)
    for dotted_key in overrides:
        update_cfg_with_hparam(cfg, dotted_key, overrides[dotted_key])
    return cfg, tag