-
Notifications
You must be signed in to change notification settings - Fork 3
/
logger.py
122 lines (100 loc) · 3.41 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys
import wandb
import os.path as osp
import torch
class CudaTimer:
    """Measure elapsed GPU time between two points using CUDA events.

    With ``active=True`` the timer owns a pair of ``torch.cuda.Event`` objects
    and synchronizes the device around them; ``end_timer()`` returns the
    elapsed time in seconds. With ``active=False`` no events are created,
    ``start_timer`` does nothing and ``end_timer()`` returns ``0.0``, so timing
    calls can stay in place while profiling is switched off.
    """

    def __init__(self, active = True):
        if not active:
            # Inactive path: bind the no-op implementations and stop here.
            self.start_timer = self._start_timer_noop
            self.end_timer = self._end_timer_noop
            return
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start_timer = self._start_timer_active
        self.end_timer = self._end_timer_active

    @classmethod
    def initiate_timer(cls, active = True):
        """Create a timer, start it immediately, and return it."""
        timer = cls(active=active)
        timer.start_timer()
        return timer

    def _start_timer_active(self):
        """Wait for pending device work, then record the start event."""
        # Synchronize first so work queued before this point is not counted.
        torch.cuda.synchronize()
        self.start.record()

    def _end_timer_active(self):
        """Record the end event, wait for the device, return seconds elapsed."""
        self.end.record()
        # The end event must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        elapsed_ms = self.start.elapsed_time(self.end)
        return elapsed_ms * 1e-3  # milliseconds -> seconds

    def _start_timer_noop(self):
        """Inactive counterpart of ``_start_timer_active``: does nothing."""
        return None

    def _end_timer_noop(self):
        """Inactive counterpart of ``_end_timer_active``: always 0.0."""
        return 0.0
def setup_logger(name, save_dir, out_file='log.txt', on_stdout = True):
    """Configure and return a DEBUG-level logger called ``name``.

    Records are formatted as ``"<asctime> <name> <levelname>: <message>"`` and
    written to ``sys.stdout`` (when ``on_stdout``) and/or to
    ``save_dir/out_file`` (when ``save_dir`` is truthy). The directory that
    will hold the log file is created if it does not exist yet.

    Repeated calls for the same ``name`` are safe. A handler is not attached
    again when the logger already has one writing to the same destination
    (``sys.stdout`` or the same absolute file path). Without this guard, every
    call would add more handlers and each message would be emitted several
    times.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        save_dir: Directory for the log file; falsy disables file logging.
        out_file: File name (relative to ``save_dir``) of the log file.
        on_stdout: Whether to also log to ``sys.stdout``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    if on_stdout:
        has_stdout = any(
            getattr(h, "stream", None) is sys.stdout for h in logger.handlers
        )
        if not has_stdout:
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)

    if save_dir:
        # FileHandler stores the absolute path in ``baseFilename``; compare on that.
        log_path = os.path.abspath(os.path.join(save_dir, out_file))
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not has_file:
            # FileHandler fails with FileNotFoundError if the directory is missing.
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
def setup_noop_logger(name='noop'):
    """Return a logger called ``name`` that discards every record.

    A single ``NullHandler`` is attached; repeated calls do not add more.
    Propagation to ancestor loggers is switched off. Otherwise records would
    still reach any handlers configured on the root logger (for example after
    ``logging.basicConfig``), and the logger would not actually be silent.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The silent ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
        logger.addHandler(logging.NullHandler())
    logger.propagate = False
    return logger
# Weights & Biases (W&B) settings used when runs are created or resumed below.
# Credentials are not configured in this module. Provide the API key through the
# WANDB_API_KEY environment variable (or `wandb login`) rather than hard-coding
# it, e.g. os.environ["WANDB_API_KEY"] = "<Insert your key>".
project = "polygon-hgt"  # W&B project that every run is logged to
entity = "polygon-hgt"   # W&B entity (user or team) that owns the project
def wandb_init(cfg, checkpointer, resume=False, timestamp = '', disable_wandb = False):
    """Start a new W&B run, or reattach to the one recorded on ``checkpointer``.

    Args:
        cfg: Experiment config, forwarded to the run-creation helpers.
        checkpointer: Object carrying the W&B run id. When ``resume`` is set,
            its ``wandb_id`` selects the run to reattach to. In every case the
            id of the returned run is written to ``checkpointer.wandb_id``.
        resume: Reattach to the stored run instead of creating a new one.
        timestamp: Suffix for the name of a newly created run (ignored when
            resuming).
        disable_wandb: Use W&B mode ``'disabled'`` instead of ``'online'``.

    Returns:
        The run object returned by ``wandb.init``.
    """
    mode = 'online' if not disable_wandb else 'disabled'
    if not resume:
        wandb_run = _wandb_from_config(cfg, timestamp, mode)
    else:
        wandb_run = _wandb_from_checkpoint(cfg, checkpointer, mode)
    # NOTE(review): presumably the checkpointer saves wandb_id with its
    # checkpoints so a later resume=True call can find this run; confirm.
    checkpointer.wandb_id = wandb_run.id
    return wandb_run
def _wandb_from_config(cfg, timestamp, mode):
    """Start a brand-new W&B run described by ``cfg``.

    Args:
        cfg: Experiment config. ``EXPERIMENT.NAME``, ``EXPERIMENT.GROUP``,
            ``EXPERIMENT.NOTES`` and ``OUTPUT_DIR`` are read from it, and
            ``cfg`` itself is passed as the run config (presumably a dict-like
            config node; confirm against the config loader).
        timestamp: Appended to the experiment name as ``"<NAME>-<timestamp>"``.
            When it is empty (the default in ``wandb_init``), the run is named
            just ``"<NAME>"`` instead of ending in a dangling ``"-"``.
        mode: W&B mode string passed straight through to ``wandb.init``.

    Returns:
        The run object returned by ``wandb.init``.
    """
    if timestamp:
        run_name = f'{cfg.EXPERIMENT.NAME}-{timestamp}'
    else:
        run_name = f'{cfg.EXPERIMENT.NAME}'
    return wandb.init(
        name=run_name,
        group=cfg.EXPERIMENT.GROUP,
        notes=cfg.EXPERIMENT.NOTES,
        # os.fspath accepts str or PathLike and yields a plain string path.
        dir=os.fspath(cfg.OUTPUT_DIR),
        project=project,
        entity=entity,
        # Always a fresh run here; resuming goes through _wandb_from_checkpoint.
        resume='never',
        config=cfg,
        mode=mode,
    )
def _wandb_from_checkpoint(cfg, checkpoint, mode):
    """Reattach to the existing W&B run whose id is stored on ``checkpoint``.

    Args:
        cfg: Experiment config. ``cfg.OUTPUT_DIR`` is used as the W&B directory
            and ``cfg`` itself is passed as the run config.
        checkpoint: Object exposing a ``wandb_id`` attribute naming the run.
        mode: W&B mode string passed straight through to ``wandb.init``.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        ValueError: If ``checkpoint`` has no ``wandb_id`` or it is empty.
    """
    wandb_id = getattr(checkpoint, 'wandb_id', None)
    # An explicit check rather than ``assert``. Asserts are stripped under
    # ``python -O``, which would let wandb.init be called with id=None.
    if not wandb_id:
        raise ValueError(
            "Cannot resume W&B logging: the checkpoint has no wandb_id "
            "(was the original run started through wandb_init?)"
        )
    return wandb.init(
        id=wandb_id,
        # os.fspath accepts str or PathLike and yields a plain string path.
        dir=os.fspath(cfg.OUTPUT_DIR),
        project=project,
        entity=entity,
        # 'must': continue exactly this run id rather than starting a new run.
        resume='must',
        config=cfg,
        mode=mode,
    )