# utils.py
import os
import glob
import shutil

import torch
import torch.nn as nn

from logger_config import logger


class AttrDict:
    # Bare container class; attributes are set dynamically by callers.
    pass
def save_checkpoint(state: dict, is_best: bool, filename: str):
    torch.save(state, filename)
    ckpt_dir = os.path.dirname(filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(ckpt_dir, 'model_best.mdl'))
    # Always keep a copy of the most recent checkpoint.
    shutil.copyfile(filename, os.path.join(ckpt_dir, 'model_last.mdl'))
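# Usage sketch (hypothetical state dict and path; not part of the original file):
#     state = {'epoch': epoch, 'state_dict': model.state_dict()}
#     save_checkpoint(state, is_best=True, filename='./checkpoint/epoch_{}.mdl'.format(epoch))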
def delete_old_ckt(path_pattern: str, keep: int = 5):
    # Sort matching checkpoints from newest to oldest and drop everything past `keep`.
    files = sorted(glob.glob(path_pattern), key=os.path.getmtime, reverse=True)
    for f in files[keep:]:
        logger.info('Delete old checkpoint {}'.format(f))
        # os.remove is portable and handles paths with spaces, unlike os.system('rm -f ...').
        os.remove(f)
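# Usage sketch (hypothetical glob pattern; not part of the original file):
#     delete_old_ckt('./checkpoint/epoch_*.mdl', keep=3)  # keeps the 3 newest checkpoints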
def report_num_trainable_parameters(model: torch.nn.Module) -> int:
    assert isinstance(model, torch.nn.Module), 'Argument must be nn.Module'
    num_parameters = 0
    for name, p in model.named_parameters():
        if p.requires_grad:
            # p.numel() replaces np.prod(list(p.size())); same count, no numpy dependency.
            num_parameters += p.numel()
            logger.info('{}: {}'.format(name, p.numel()))
    logger.info('Number of parameters: {}M'.format(num_parameters // 10**6))
    return num_parameters
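# Usage sketch (any nn.Module works; the layer below is arbitrary):
#     report_num_trainable_parameters(nn.Linear(768, 768))  # logs per-tensor counts, returns the total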
def get_model_obj(model: nn.Module):
    # Unwrap DataParallel / DistributedDataParallel containers, which expose
    # the underlying model as `.module`.
    return model.module if hasattr(model, "module") else model
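# Usage sketch (assumes a DistributedDataParallel wrapper; not part of the original file):
#     ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[0])
#     torch.save(get_model_obj(ddp_model).state_dict(), 'model.mdl')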
def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            # Preserve the container type; the original returned a list here.
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        else:
            return maybe_tensor

    return _move_to_cuda(sample)
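# Usage sketch (requires a CUDA-capable device; batch layout is hypothetical):
#     batch = {'input_ids': torch.ones(2, 8, dtype=torch.long), 'labels': [0, 1]}
#     batch = move_to_cuda(batch)  # tensors move to GPU, non-tensor values pass through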
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
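# Usage sketch (hypothetical loss values; not part of the original file):
#     losses = AverageMeter('Loss', ':.4f')
#     losses.update(0.73, n=32)  # running sum is weighted by batch size n
#     logger.info(str(losses))   # -> "Loss 0.7300 (0.7300)"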
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches: int) -> str:
        # `num_batches // 1` in the original was a no-op; the digit count of
        # num_batches itself is what sets the field width.
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
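# Minimal runnable sketch tying the meters together (batch counts and loss
# values below are made up, not part of the original file).
if __name__ == '__main__':
    losses = AverageMeter('Loss', ':.4f')
    progress = ProgressMeter(num_batches=100, meters=[losses], prefix='Epoch[0] ')
    for step in range(1, 4):
        losses.update(1.0 / step, n=32)
        progress.display(step)  # e.g. "Epoch[0] [  3/100]\tLoss 0.3333 (0.6111)"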