-
Notifications
You must be signed in to change notification settings - Fork 42
/
loss_plot.py
65 lines (51 loc) · 1.72 KB
/
loss_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Plot discriminator and generator losses from a GAN training log.

Usage::

    python loss_plot.py [LOG_DIR]

The log file ``LOG_DIR/training_log_corrected.txt`` is read line by line.
Each line is split on single spaces and values are picked by position:
field 1 is the train step; fields 5, 8 and 11 are the discriminator real
loss, discriminator fake loss and generator loss. Two figures are shown:
the discriminator real/fake losses together, then the generator loss alone.
"""
import os
import sys

LOG_FILENAME = "training_log_corrected.txt"

# Positions of the values in a line split on single spaces.
_STEP_FIELD = 1
_REAL_FIELD = 5
_FAKE_FIELD = 8
_GEN_FIELD = 11


def parse_log_line(line):
    """Parse one training-log line.

    Returns:
        ``(train_step, disc_real_loss, disc_fake_loss, gan_loss)`` as
        ``(int, float, float, float)``.

    Raises:
        IndexError: if the line has fewer than 12 space-separated fields.
        ValueError: if a selected field is not a valid number.
    """
    # Strip only trailing whitespace so the newline (when present) does not
    # stick to the last field. Unlike chopping the final character, this is
    # also correct for a last line that has no newline. Leading text is left
    # alone so the field positions are not shifted.
    fields = line.rstrip().split(" ")
    return (
        int(fields[_STEP_FIELD]),
        float(fields[_REAL_FIELD]),
        float(fields[_FAKE_FIELD]),
        float(fields[_GEN_FIELD]),
    )


def read_training_log(path):
    """Parse every non-blank line of the log file at *path*.

    Returns:
        Four parallel lists:
        ``(train_steps, disc_real_loss, disc_fake_loss, gan_loss)``.
    """
    train_steps, disc_real_loss, disc_fake_loss, gan_loss = [], [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            step, real, fake, gen = parse_log_line(line)
            train_steps.append(step)
            disc_real_loss.append(real)
            disc_fake_loss.append(fake)
            gan_loss.append(gen)
    return train_steps, disc_real_loss, disc_fake_loss, gan_loss


def _plot_losses(train_steps, series):
    """Show one figure plotting each ``(values, label)`` in *series* against *train_steps*."""
    # Imported here so the parsing helpers above can be used without matplotlib.
    import matplotlib.pyplot as plt

    _, ax = plt.subplots()
    for values, label in series:
        ax.plot(train_steps, values, "-", label=label)
    ax.legend(loc="best")
    ax.grid()
    ax.set_xlabel("Train steps")
    ax.set_ylabel("Loss")
    plt.show()


def main(log_dir="."):
    """Read ``<log_dir>/training_log_corrected.txt`` and show the loss plots."""
    train_steps, disc_real_loss, disc_fake_loss, gan_loss = read_training_log(
        os.path.join(log_dir, LOG_FILENAME)
    )
    # Real vs. fake discriminator loss on one figure.
    _plot_losses(
        train_steps,
        [
            (disc_real_loss, "Discriminator real loss"),
            (disc_fake_loss, "Discriminator fake loss"),
        ],
    )
    # Generator loss on its own figure.
    _plot_losses(train_steps, [(gan_loss, "Generator loss")])


if __name__ == "__main__":
    # Log directory: first CLI argument if given. Otherwise use a `log_dir`
    # already present in this namespace (e.g. when run via IPython
    # `%run -i`). Otherwise fall back to the current directory.
    main(sys.argv[1] if len(sys.argv) > 1 else globals().get("log_dir", "."))