utils.py
"""Utility helpers for training: seeding, config parsing, and factories
for models and LR schedulers."""
import torch
import numpy as np
import yaml

import models


def set_seed(seed):
    """Make runs reproducible: seed the PyTorch and NumPy RNGs and force
    deterministic cuDNN kernels (at some cost in speed)."""
    torch.manual_seed(seed)  # also seeds CUDA on current PyTorch versions
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)  # note: Python's built-in `random` is not seeded here


def parse_config(path):
    """Parse a YAML experiment config into a plain dict."""
    with open(path, 'r') as ymlfile:
        cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)
    return cfg
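
# A config sketch for get_model below. The keys are inferred from how the
# function reads the dict; the concrete values are illustrative only:
#
#   custom: false
#   tvision:
#     name: resnet18
#     args: {num_classes: 10}
#
# With custom: true, 'arch' names a model in the local zoo and the optional
# 'kwargs' mapping is forwarded to its constructor.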


def get_model(cnfg):
    """Instantiate a model, either from the local zoo or from torchvision."""
    if cnfg['custom']:
        kwargs = cnfg.get('kwargs', {})
        return models.get_from_zoo(cnfg['arch'], kwargs)
    return models.get_tvision(cnfg['tvision']['name'], cnfg['tvision']['args'])


def get_scheduler(opt, cnfg, steps):
    """Build the LR scheduler named in the config; `steps` is the total
    number of scheduler steps in one cycle (cyclic schedule only)."""
    if cnfg['lr_scheduler'] == 'cyclic':
        # CyclicLR expects integer step sizes, so use floor division
        return torch.optim.lr_scheduler.CyclicLR(opt,
                                                 base_lr=cnfg['lr_min'],
                                                 max_lr=cnfg['lr_max'],
                                                 step_size_up=steps // 2,
                                                 step_size_down=steps // 2)
    elif cnfg['lr_scheduler'] == 'step':
        return torch.optim.lr_scheduler.StepLR(opt,
                                               step_size=cnfg['step'],
                                               gamma=cnfg['gamma'])
    elif cnfg['lr_scheduler'] == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(opt,
                                                    milestones=cnfg['milestones'],
                                                    gamma=cnfg['gamma'])
    raise NotImplementedError(
        f"[ERROR] Scheduler '{cnfg['lr_scheduler']}' is not implemented")


def save_model(model, cnf, epoch, path):
    """Checkpoint the model weights together with the config, epoch and
    architecture name."""
    state = {
        'epoch': epoch,
        'cnf': cnf,
        'arch': type(model).__name__,
        'model': model.state_dict()
    }
    torch.save(state, path)
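
# Loading counterpart (a sketch, not part of the original file): restores a
# checkpoint written by save_model. Rebuilding the network via get_model
# assumes the saved 'cnf' is the same config dict used at save time.
def load_model(path):
    state = torch.load(path)
    model = get_model(state['cnf'])
    model.load_state_dict(state['model'])
    return model, state['epoch']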


def get_lr(opt):
    """Return the current learning rate of every parameter group."""
    return [param_group['lr'] for param_group in opt.param_groups]


def adjust_lr(opt, sc, log, stp, do_log=True):
    """Advance the scheduler one step and optionally log the new LR."""
    sc.step()
    if do_log:
        log_lr(log, opt, stp)


def log_lr(log, opt, stp):
    """Record the optimizer's per-group learning rates at training step `stp`."""
    lr = get_lr(opt)
    log.log_lr(lr, stp)
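

# Minimal usage sketch (not part of the original module): wires the helpers
# together. 'config.yml', the seed default and the stub logger are
# hypothetical stand-ins for what the real training script provides.
if __name__ == '__main__':
    class _PrintLogger:
        """Stub exposing the log_lr(lr, step) hook that log_lr() calls."""
        def log_lr(self, lr, step):
            print(f"step {step}: lr={lr}")

    cfg = parse_config('config.yml')              # hypothetical path
    set_seed(cfg.get('seed', 42))                 # hypothetical default
    model = get_model(cfg)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    sched = get_scheduler(opt, cfg, steps=1000)
    logger = _PrintLogger()
    for step in range(3):
        opt.step()                                # normally after backward()
        adjust_lr(opt, sched, logger, step)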