-
Notifications
You must be signed in to change notification settings - Fork 3
/
loggers.py
99 lines (76 loc) · 2.9 KB
/
loggers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys
import time
class TextLogger(object):
    """Tee stream output to both a stream and an external text file.

    Args:
        filename (str): path of the text file that receives a copy of every
            write. The file is opened in append mode with UTF-8 encoding.
        stream: the stream that keeps receiving the output (e.g. the console).
            Default: ``sys.stdout`` as it was when this module was imported
            (default arguments are evaluated once, at definition time).

    The stream is borrowed, not owned: :meth:`close` closes only the log file
    and leaves the stream open, so closing the logger never closes the console.
    """

    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        # Explicit encoding so non-ASCII output does not depend on the platform locale.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the stream and, while it is open, to the log file."""
        self.terminal.write(message)
        # This object may still be installed as sys.stdout/sys.stderr after
        # close(); keep echoing to the stream instead of raising on a closed file.
        if not self.log.closed:
            self.log.write(message)
        self.flush()

    def flush(self):
        """Flush the stream and, if it is still open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def isatty(self):
        """Report whether the wrapped stream is interactive.

        Some libraries call ``sys.stdout.isatty()``; delegating avoids an
        AttributeError once this object replaces ``sys.stdout``.
        """
        isatty = getattr(self.terminal, 'isatty', None)
        return bool(isatty is not None and isatty())

    def close(self):
        """Close the log file; the wrapped stream is intentionally left open."""
        self.log.close()
class CompleteLogger:
    """
    A logger that
      - writes outputs to a text file and displays them on the console at the same time,
      - manages the directories of checkpoints and debugging images.

    Args:
        root (str): the root directory of the logger. ``visualize/`` and
            ``checkpoints/`` sub-directories are created inside it.
        log_name (str): prefix of the text log file, which is written to
            ``{root}/{log_name}_{timestamp}.txt``. Default: ``'train'``.
        phase (str, optional): the phase of training. When it is ``'train'``,
            default image folders and checkpoint names use the current epoch;
            otherwise they use the phase name. Default: ``None``, meaning the
            value of ``log_name`` is used.
    """

    def __init__(self, root, log_name='train', phase=None):
        self.root = root
        self.log_name = log_name
        # _get_phase_or_epoch() reads self.phase, so it must always be set.
        self.phase = log_name if phase is None else phase
        self.visualize_directory = os.path.join(self.root, "visualize")
        self.checkpoint_directory = os.path.join(self.root, "checkpoints")
        self.epoch = 0

        os.makedirs(self.root, exist_ok=True)
        os.makedirs(self.visualize_directory, exist_ok=True)
        os.makedirs(self.checkpoint_directory, exist_ok=True)

        # Redirect stdout/stderr into a TextLogger that also appends to the log file.
        now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
        log_filename = os.path.join(self.root, "{}_{}.txt".format(self.log_name, now))
        try:
            # Start from an empty file even if one with the same timestamp exists.
            os.remove(log_filename)
        except FileNotFoundError:
            pass
        # TextLogger's default console stream is stdout, so stderr output is
        # echoed to stdout as well while redirected.
        self.logger = TextLogger(log_filename)
        # Remember the streams being replaced so close() can restore them.
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = self.logger
        sys.stderr = self.logger

    def set_epoch(self, epoch):
        """Set the current epoch number. Call it during training."""
        self.epoch = epoch

    def _get_phase_or_epoch(self):
        """Return the epoch (as a string) in the ``train`` phase, otherwise the phase name."""
        if self.phase == 'train':
            return str(self.epoch)
        return self.phase

    def get_image_path(self, filename: str):
        """
        Get the full path for a debugging image, creating its folder if needed.

        The image is placed in ``visualize/{epoch}/`` during the ``train`` phase
        and in ``visualize/{phase}/`` otherwise.
        """
        directory = os.path.join(self.visualize_directory, self._get_phase_or_epoch())
        # Per-epoch/per-phase folders are not created in advance; create on first use.
        os.makedirs(directory, exist_ok=True)
        return os.path.join(directory, filename)

    def get_checkpoint_path(self, name=None):
        """
        Get the full checkpoint path.

        Args:
            name (optional): the filename (without file extension) to save the checkpoint.
                If None, the checkpoint is saved to ``{epoch}.pth`` when the phase is
                ``train``, and to ``{phase}.pth`` otherwise.
        """
        if name is None:
            name = self._get_phase_or_epoch()
        return os.path.join(self.checkpoint_directory, str(name) + ".pth")

    def close(self):
        """Restore the original stdout/stderr and close the log file."""
        # Only undo our own redirection; leave alone any later replacement.
        if sys.stdout is self.logger:
            sys.stdout = self._original_stdout
        if sys.stderr is self.logger:
            sys.stderr = self._original_stderr
        self.logger.close()