-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
99 lines (92 loc) · 2.81 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import cv2
import numpy as np
import plotly
from plotly.graph_objs import Scatter
from plotly.graph_objs.scatter import Line
# Plots min, max, median and mean +/- one standard deviation of a population over time
def lineplot(xs, ys_population, title, path="", xaxis="episode"):
    """Write a plotly line chart of ``ys_population`` against ``xs`` to HTML.

    The chart is written to ``os.path.join(path, title + ".html")`` via
    ``plotly.offline.plot`` with ``auto_open=False``.

    Args:
        xs: x-axis values, one per entry of ``ys_population``.
        ys_population: Either a flat sequence of values (plotted as a single
            line), or a sequence whose entries are lists, tuples or non-scalar
            NumPy arrays of samples. In the latter case the per-x min, max,
            median, mean and a mean +/- one standard deviation band are
            plotted. The sample rows are stacked into one 2-D float32 array,
            so they must all have the same length. An empty sequence yields
            an empty single-line plot.
        title: Plot title; also used as the y-axis label and as the output
            file name stem.
        path: Directory the HTML file is written to ("" = current directory).
        xaxis: Label for the x axis.
    """
    max_colour, mean_colour, std_colour, transparent = (
        "rgb(0, 132, 180)",
        "rgb(0, 172, 237)",
        "rgba(29, 202, 255, 0.2)",
        "rgba(0, 0, 0, 0)",
    )
    # Guard the peek at the first entry so empty input produces an empty
    # plot instead of an IndexError.
    first = ys_population[0] if len(ys_population) > 0 else None
    # Entries that are themselves sequences mean "several samples per x".
    is_population = isinstance(first, (list, tuple)) or (
        isinstance(first, np.ndarray) and first.ndim > 0
    )
    if is_population:
        ys = np.asarray(ys_population, dtype=np.float32)
        # Statistics are taken along axis 1, i.e. across the samples at each x.
        ys_min, ys_max, ys_mean, ys_std, ys_median = (
            ys.min(1),
            ys.max(1),
            ys.mean(1),
            ys.std(1),
            np.median(ys, 1),
        )
        ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std
        trace_max = Scatter(
            x=xs, y=ys_max, line=Line(color=max_colour, dash="dash"), name="Max"
        )
        trace_upper = Scatter(
            x=xs,
            y=ys_upper,
            line=Line(color=transparent),
            name="+1 Std. Dev.",
            showlegend=False,
        )
        trace_mean = Scatter(
            x=xs,
            y=ys_mean,
            fill="tonexty",
            fillcolor=std_colour,
            line=Line(color=mean_colour),
            name="Mean",
        )
        trace_lower = Scatter(
            x=xs,
            y=ys_lower,
            fill="tonexty",
            fillcolor=std_colour,
            line=Line(color=transparent),
            name="-1 Std. Dev.",
            showlegend=False,
        )
        trace_min = Scatter(
            x=xs, y=ys_min, line=Line(color=max_colour, dash="dash"), name="Min"
        )
        trace_median = Scatter(
            x=xs, y=ys_median, line=Line(color=max_colour), name="Median"
        )
        # Order matters: fill="tonexty" shades towards the trace listed just
        # before it. So the mean fills up to the invisible upper bound, and
        # the invisible lower bound fills up to the mean, forming the band.
        data = [
            trace_upper,
            trace_mean,
            trace_lower,
            trace_min,
            trace_max,
            trace_median,
        ]
    else:
        data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))]
    plotly.offline.plot(
        {
            "data": data,
            "layout": dict(title=title, xaxis={"title": xaxis}, yaxis={"title": title}),
        },
        filename=os.path.join(path, title + ".html"),
        auto_open=False,
    )
def write_video(frames, title, path="", *, fps=30.0):
    """Encode ``frames`` as an ``mp4v`` video at ``<path>/<title>.mp4``.

    Args:
        frames: Sequence of equally-shaped channels-first (C x H x W) arrays.
            Values are multiplied by 255, clipped to [0, 255] and truncated to
            uint8, so they are presumably floats in [0, 1]. The channel axis
            is reversed before writing; this presumes RGB input, and the
            writer is opened in colour mode, so 3 channels are assumed.
            TODO confirm with callers.
        title: Output file name stem.
        path: Directory the video is written to ("" = current directory).
        fps: Frame rate of the output video (defaults to the previous
            hard-coded 30).
    """
    # (T, C, H, W) -> (T, H, W, C), scale to uint8, then flip channels because
    # cv2.VideoWriter expects H x W x C frames in BGR order.
    video = np.stack(frames, axis=0).transpose(0, 2, 3, 1)
    video = np.multiply(video, 255).clip(0, 255).astype(np.uint8)[:, :, :, ::-1]
    # The channel flip yields a negatively-strided view; give OpenCV a
    # contiguous buffer rather than relying on its implicit copy.
    video = np.ascontiguousarray(video)
    _, height, width, _ = video.shape
    writer = cv2.VideoWriter(
        os.path.join(path, f"{title}.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
        True,
    )
    # Always release the writer so the file handle is closed and the
    # container is finalised, even if a write fails part-way.
    try:
        for frame in video:
            writer.write(frame)
    finally:
        writer.release()