-
Notifications
You must be signed in to change notification settings - Fork 6
/
graph.py
104 lines (87 loc) · 3.48 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import signal
import json
import logging
import argparse
import torch
from collections import defaultdict
import numpy as np
# Configure the root logger once, at import time, to emit records at INFO level and above.
# NOTE(review): the empty format string presumably makes logging fall back to its
# default message-only layout -- confirm this is the intended output format.
logging.basicConfig(level=logging.INFO, format='')
def _collect_metrics(log):
    """Group every logged value by metric name, paired with its iteration."""
    graphs = defaultdict(lambda: {'iters': [], 'values': []})
    for entry in log.entries.values():
        iteration = entry['iteration']
        for metric, value in entry.items():
            if metric != 'iteration':
                graphs[metric]['iters'].append(iteration)
                graphs[metric]['values'].append(value)
    return graphs


def _summarize(values):
    """Return (max, min, mean) over axis 0, or None when values are not numeric."""
    try:
        ndata = np.array(values)
    except ValueError:
        # Recent NumPy versions refuse to build an array from ragged sequences.
        return None
    # Only bool/int/uint/float/complex arrays support max, min and mean;
    # string or object arrays would raise inside the reductions.
    if ndata.dtype.kind not in 'biufc':
        return None
    return ndata.max(axis=0), ndata.min(axis=0), ndata.mean(axis=0)


def _is_selected(metric, substring):
    """Tell whether a metric passes the display filter."""
    if substring is None:
        return metric.startswith(('avg', 'val'))
    return substring in metric


def graph(log, plot=True, substring=None, max_figures=15):
    """Summarize the metrics stored in a training log and plot or print them.

    Every metric found in ``log.entries`` (each entry is a mapping that holds
    an ``'iteration'`` key plus metric values) gets a one-line max/min/mean
    summary printed, provided its values form a numeric array. Afterwards the
    selected metrics are either plotted against the iteration number
    (``plot`` truthy) or printed as raw value lists (``plot`` falsy).

    Args:
        log: object exposing an ``entries`` mapping of per-iteration dicts.
        plot: when truthy, open one matplotlib figure per selected metric;
            otherwise print the metric name followed by its list of values.
        substring: if given, select the metrics whose name contains it;
            if None, select the metrics whose name starts with 'avg' or 'val'.
            The filter applies to both the plotting and the printing mode.
        max_figures: upper bound on the number of figure windows opened.

    Returns:
        None. Results are printed and/or shown in matplotlib windows.
    """
    graphs = _collect_metrics(log)

    print('summed')
    skip = set()
    for metric, data in graphs.items():
        stats = _summarize(data['values'])
        if stats is None:
            # Non-numeric metrics get no summary and cannot be plotted.
            skip.add(metric)
            continue
        print('{} max: {}, min: {}, mean: {}'.format(metric, *stats))

    selected = [(metric, data) for metric, data in graphs.items()
                if _is_selected(metric, substring)]

    if plot:
        # Imported lazily so that text-only use does not require matplotlib.
        import matplotlib.pyplot as plt
        figures = 0
        for metric, data in selected:
            if metric in skip:
                continue
            figures += 1
            plt.figure(figures)
            plt.plot(data['iters'], data['values'], '.-')
            plt.xlabel('iterations')
            plt.ylabel(metric)
            plt.title(metric)
            if figures >= max_figures:
                print('WARNING, too many windows, stopping')
                break
        plt.show()
    else:
        # Raw values are printed even for non-numeric metrics.
        for metric, data in selected:
            print(metric)
            print(data['values'])
def str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line entry point: load a checkpoint, then either print its config,
# graph its training log, or extract the log into a smaller file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-c', '--checkpoint', default='../..//Downloads/export.pth', type=str,
                        help='checkpoint file path (default: None)')
    parser.add_argument('-p', '--plot', default=1, type=int,
                        help='plot (default: True)')
    parser.add_argument('-o', '--only', default=None, type=str,
                        help='only stats with all these substrings (default: None)')
    parser.add_argument('-e', '--extract', default=None, type=str,
                        help='instead of ploting, save a new file with only the log (default: None)')
    # nargs='?' lets a bare "-C" mean True while "-C True"/"-C False" still work.
    parser.add_argument('-C', '--printconfig', nargs='?', const=True, default=False,
                        type=str2bool,
                        help='print config and exit (default: False)')

    args = parser.parse_args()

    # Explicit check instead of assert, which is stripped when running with -O.
    if args.checkpoint is None:
        parser.error('a checkpoint path is required')

    # NOTE(review): torch.load unpickles arbitrary objects; only open
    # checkpoints that come from a trusted source.
    # The map_location callback keeps every tensor storage where it was
    # deserialized instead of moving it to the device it was saved from.
    saved = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    log = saved['logger']
    iteration = saved['iteration']
    print('loaded iteration {}'.format(iteration))

    if args.printconfig:
        print(saved['config'])
        sys.exit()

    # Drop the reference to the full checkpoint; only the log is needed from here on.
    saved = None

    if args.extract is None:
        graph(log, args.plot, args.only)
    else:
        new_save = {
            'iteration': iteration,
            'logger': log,
        }
        new_file = args.extract
        torch.save(new_save, new_file)
        print('saved ' + new_file)