-
Notifications
You must be signed in to change notification settings - Fork 2
/
exp_utils.py
executable file
·43 lines (32 loc) · 1.28 KB
/
exp_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import functools
import os, shutil
import numpy as np
import torch
def logging(s, log_path, print_=True, log_=True):
    """Print a message and optionally append it to a log file.

    Args:
        s: Message to emit. Any object is accepted; it is converted with
            ``str()`` before being written to the file, mirroring what
            ``print`` already does for stdout.
        log_path: Path of the log file. Only opened when ``log_`` is true,
            so it may be ``None`` when file logging is disabled.
        print_: Kept for interface compatibility. It is currently ignored:
            the message is always printed to stdout.
        log_: When true, append the message followed by a newline to
            ``log_path``.
    """
    # Printing is unconditional; the ``print_`` flag is intentionally not
    # consulted here.
    print(s)
    if log_:
        # Append mode creates the file on first use and never truncates
        # earlier entries.
        with open(log_path, 'a+') as f_log:
            # str() keeps non-string messages (numbers, exceptions, ...) from
            # raising TypeError after they were already printed.
            f_log.write(str(s) + '\n')
def get_logger(log_path, **kwargs):
    """Return a ``logging`` callable pre-bound to ``log_path``.

    Any extra keyword arguments (e.g. ``print_`` or ``log_``) are bound as
    well, so the returned callable only needs the message itself.
    """
    bound_options = {'log_path': log_path}
    bound_options.update(kwargs)
    return functools.partial(logging, **bound_options)
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    """Prepare a fresh experiment directory and return a logger for it.

    Args:
        dir_path: Directory to create. If it already exists it is deleted
            recursively and recreated empty.
        scripts_to_save: Optional iterable of file paths; each one is copied
            into ``<dir_path>/scripts`` under its base name.
        debug: When true, nothing is created on disk and the returned logger
            has file logging disabled.

    Returns:
        A callable wrapping ``logging``. In normal mode it is bound to
        ``<dir_path>/log.txt``; in debug mode it is bound to
        ``log_path=None`` with ``log_=False``.
    """
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    # Start from a clean slate: any previous contents are discarded.
    already_there = os.path.exists(dir_path)
    if already_there:
        print("Path {} exists. Remove and remake.".format(dir_path))
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)
    print('Experiment dir : {}'.format(dir_path))

    if scripts_to_save is not None:
        # Snapshot the given scripts next to the log.
        snapshot_dir = os.path.join(dir_path, 'scripts')
        if not os.path.exists(snapshot_dir):
            os.makedirs(snapshot_dir)
        for source_file in scripts_to_save:
            target_file = os.path.join(snapshot_dir,
                                       os.path.basename(source_file))
            shutil.copyfile(source_file, target_file)

    log_file = os.path.join(dir_path, 'log.txt')
    return get_logger(log_path=log_file)
def save_checkpoint(model, optimizer, path, epoch):
    """Write ``model`` to ``<path>/model_<epoch>.pt`` using ``torch.save``.

    Args:
        model: Object handed to ``torch.save`` as-is.
        optimizer: Accepted for interface compatibility; not used here, so
            no optimizer state is written.
        path: Existing directory that receives the checkpoint file.
        epoch: Value formatted into the checkpoint file name.
    """
    checkpoint_name = 'model_{}.pt'.format(epoch)
    checkpoint_file = os.path.join(path, checkpoint_name)
    torch.save(model, checkpoint_file)