-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot.py
84 lines (64 loc) · 2.47 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import visdom
import numpy as np
import matplotlib.pyplot as plt
class EmptyPlot:
    """Plotter whose ``update`` does nothing.

    ``VisdomPlot`` subclasses this and overrides ``update`` with the same
    signature, so an ``EmptyPlot`` can stand in wherever an object with that
    ``update`` method is expected but no plotting should happen.
    """

    def __init__(self):
        pass

    def update(self, title, x_val, y_val, update=None):
        """Accept a plot update and discard it.

        Args:
            title: Ignored.
            x_val: Ignored.
            y_val: Ignored.
            update: Ignored. Defaults to ``None`` instead of a shared mutable
                dict; any dict callers pass is still accepted.
        """
        # Intentionally a no-op.
        pass
class VisdomPlot(EmptyPlot):
    """Plotter that sends incremental plot updates to a Visdom server.

    Each plot is identified by its ``title``, which is also used as the
    Visdom window name. The plot type, axis labels and legend for every
    title are fixed at construction time.
    """

    def __init__(self, env_title, types, titles, xlabels, ylabels, legends,
                 server="http://127.0.0.1", port=8097):
        """Connect to Visdom and record the per-title plot settings.

        Args:
            env_title: Visdom environment name.
            types: Plot type for each title (only ``'line'`` is supported by
                ``getType``).
            titles: Plot titles; used as dict keys and Visdom window names.
            xlabels: X-axis label for each title.
            ylabels: Y-axis label for each title.
            legends: Legend entries for each title.
            server: Visdom server URL (default: local host, as before).
            port: Visdom server port (default: 8097, as before).
        """
        super().__init__()
        self.Handle = visdom.Visdom(server=server, port=port, env=env_title)
        # Pair each title with its settings. zip() stops at the shortest
        # input, so the lists are expected to have equal length.
        self.Types = dict(zip(titles, types))
        self.XLabels = dict(zip(titles, xlabels))
        self.YLabels = dict(zip(titles, ylabels))
        self.Legends = dict(zip(titles, legends))

    def update(self, title, x_val, y_val, update=None):
        """Send one new point per series to the plot named ``title``.

        Args:
            title: A title registered in ``__init__``.
            x_val: Scalar x coordinate shared by every series in this update.
            y_val: One y value per series. Expected to form a single row of
                shape ``(1, n_series)`` (the X array is built with that
                shape); 1-D or scalar input is promoted to one row.
            update: Optional dict ``{'flag': bool, 'val': ...}``. When
                ``flag`` is true, ``val`` is passed to Visdom as its
                ``update`` mode unchanged. Otherwise the window is redrawn
                from scratch (``update=None``) when ``x_val == 0`` and
                appended to for any other ``x_val``.
        """
        if update is None:
            # Fresh dict per call instead of a shared mutable default.
            update = {'flag': False, 'val': None}
        if not update['flag']:
            # Original note: "the first iteration is 1" -- presumably
            # x_val == 0 is only used to reset the window; TODO confirm.
            update_flag = None if x_val == 0 else 'append'
        else:
            update_flag = update['val']
        if update_flag is None:
            print('[VisdomPlot] clear visdom plot')
        # Promote scalars / 1-D input to a single row so that shape[1] is
        # always the number of series (previously 1-D input raised
        # IndexError). 2-D input is unchanged.
        y_val = np.atleast_2d(y_val)
        y_size = y_val.shape[1]
        # One identical x coordinate per series, matching Y's row shape.
        x_val = np.ones((1, y_size)) * x_val
        plot_func = self.getType(self.Types[title])
        plot_func(X=x_val, Y=y_val, win=title,
                  opts=dict(
                      legend=self.Legends[title],
                      title=title,
                      xlabel=self.XLabels[title],
                      ylabel=self.YLabels[title],
                  ),
                  update=update_flag)

    def getType(self, t):
        """Return the Visdom plotting method for plot type ``t``.

        Raises:
            NotImplementedError: If ``t`` is not ``'line'``.
        """
        if t == 'line':
            return self.Handle.line
        # f-string: '%s' % t would mis-format (or raise) for tuple values.
        raise NotImplementedError(f'[VisdomPlot] Not supported type: {t}')
def plotLine(points_list,
             label_list,
             title='',
             gap=100,
             color_list=('red',),
             style_list=('-',),
             grid=True,
             xlim=None,
             ylim=None,
             save_path=None):
    """Plot several series against a shared, evenly spaced x axis.

    The j-th point of every series is drawn at ``x = j * gap``.

    Args:
        points_list: Iterable of y-value sequences, one per line.
        label_list: Legend label for each line. Lines beyond the end of
            this list are drawn without a label.
        title: Figure title.
        gap: Spacing between consecutive x positions.
        color_list: Line colors, reused cyclically when there are more lines
            than colors (previously extra lines were silently dropped). An
            empty list leaves the color choice to matplotlib.
        style_list: Line styles, reused cyclically like ``color_list``.
        grid: Whether to draw horizontal grid lines (y axis only).
        xlim: Optional x-axis limits passed to ``plt.xlim``.
        ylim: Optional y-axis limits passed to ``plt.ylim``.
        save_path: If given, the figure is saved there before ``plt.show()``.
    """
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    plt.grid(grid, axis='y', color='black', linestyle='--')
    plt.title(title)
    # Materialise once so lengths are known and items can be indexed.
    labels = list(label_list)
    colors = list(color_list) or [None]
    styles = list(style_list) or [None]
    for i, points in enumerate(points_list):
        x = [j * gap for j in range(len(points))]
        plt.plot(x, points,
                 color=colors[i % len(colors)],
                 linestyle=styles[i % len(styles)],
                 label=labels[i] if i < len(labels) else None)
    # Only build a legend when at least one line carries a label.
    if any(label is not None for label in labels):
        plt.legend()
    # NOTE(review): the current pyplot figure is reused and never closed; with
    # a non-interactive backend, repeated calls draw onto the same figure.
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()