-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
51 lines (41 loc) · 1.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import os
import glob
def load_latest_ckpt(exp_name, epoch_id=-1):
    """Locate a checkpoint file for *exp_name* under ``checkpoints/<exp_name>/``.

    Args:
        exp_name: experiment sub-directory name under ``checkpoints/``.
        epoch_id: specific epoch to select; ``-1`` (default) picks the
            checkpoint with the highest epoch number.

    Returns:
        Tuple ``(path, epoch)``: the chosen checkpoint path and its epoch number.

    Raises:
        FileNotFoundError: if no matching checkpoint file exists.
    """
    def get_epoch_num(path):
        # Filenames look like "epoch=<N>-...": drop the "epoch=" prefix
        # of the first '-'-separated field and parse <N>.
        base = os.path.basename(path)
        return int(base.split('-')[0][len('epoch='):])

    if epoch_id == -1:
        all_ckpts = glob.glob(f"checkpoints/{exp_name}/epoch*")
    else:
        all_ckpts = glob.glob(f"checkpoints/{exp_name}/epoch={epoch_id}-*")

    # `assert` is stripped under `python -O`; raise an explicit error instead.
    if not all_ckpts:
        raise FileNotFoundError(
            f"no checkpoint found for {exp_name}, epoch={epoch_id}"
        )

    # Select the highest-epoch checkpoint directly instead of sorting the
    # whole list just to take its first element.
    best = max(all_ckpts, key=get_epoch_num)
    return best, get_epoch_num(best)
def move_to_cuda(sample):
    """Return *sample* with every nested tensor transferred to the default CUDA device."""
    return apply_to_sample(lambda tensor: tensor.cuda(), sample)
def apply_to_sample(f, sample):
    """Apply *f* to every tensor nested inside *sample*.

    Recurses through dicts, lists, and tuples; every tensor leaf is replaced
    by ``f(tensor)``; any other leaf value is returned unchanged.

    Args:
        f: callable taking a tensor and returning its replacement.
        sample: arbitrarily nested structure of dicts/lists/tuples/tensors.

    Returns:
        A structure mirroring *sample* with ``f`` applied to each tensor.
        NOTE: an empty sized input returns ``{}`` regardless of its original
        type — kept for backward compatibility with existing callers.
    """
    # Guard with hasattr so unsized leaves (ints, floats, 0-d tensors, ...)
    # don't raise TypeError on len().
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            # Use a distinct element name — the original shadowed `x`.
            return [_apply(item) for item in x]
        elif isinstance(x, tuple):
            # Generalization: recurse into tuples too, preserving structure.
            return tuple(_apply(item) for item in x)
        else:
            return x

    return _apply(sample)