training.py
"""
Pytorch implementation of Self-Supervised GAN
Reference: "Self-Supervised GANs via Auxiliary Rotation Loss"
Authors: Ting Chen,
Xiaohua Zhai,
Marvin Ritter,
Mario Lucic and
Neil Houlsby
https://arxiv.org/abs/1811.11212 CVPR 2019.
Script Author: Vandit Jain. Github:vandit15
"""
import imageio
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torch.autograd import Variable
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
class Trainer():
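    """Trainer for a self-supervised GAN: a WGAN-GP style critic/generator setup
    augmented with the auxiliary rotation-prediction loss of Chen et al. (2019)."""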
def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer,
weight_rotation_loss_d, weight_rotation_loss_g, gp_weight=10, critic_iterations=5, print_every=50,
use_cuda=False):
self.G = generator
self.G_opt = gen_optimizer
self.D = discriminator
self.D_opt = dis_optimizer
self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
self.num_steps = 0
self.use_cuda = use_cuda
self.gp_weight = gp_weight
self.critic_iterations = critic_iterations
self.print_every = print_every
self.weight_rotation_loss_d = weight_rotation_loss_d
self.weight_rotation_loss_g = weight_rotation_loss_g
if self.use_cuda:
self.G.cuda()
self.D.cuda()
def _critic_train_iteration(self, data, generated_data, batch_size):
""" """
# Calculate probabilities on real and generated data
data = Variable(data)
if self.use_cuda:
data = data.cuda()
_, d_real_pro_logits, d_real_rot_logits, d_real_rot_prob = self.D(data)
_, g_fake_pro_logits, g_fake_rot_logits, g_fake_rot_prob = self.D(generated_data)
# Get gradient penalty
gradient_penalty = self._gradient_penalty(data, generated_data)
self.losses['GP'].append(gradient_penalty.data)
# Create total loss and optimize
self.D_opt.zero_grad()
d_loss = torch.sum(g_fake_pro_logits) - torch.sum(d_real_pro_logits) + gradient_penalty
        # Add auxiliary rotation loss: the batch is laid out as four consecutive
        # blocks of `batch_size` samples (0, 90, 180, 270 degree views), so label
        # each block with its rotation class and one-hot encode the labels
        rot_labels = torch.arange(4).repeat_interleave(batch_size)
        if self.use_cuda:
            rot_labels = rot_labels.cuda()
        rot_labels = F.one_hot(rot_labels, 4).float()
d_real_class_loss = torch.sum(F.binary_cross_entropy_with_logits(
input = d_real_rot_logits,
target = rot_labels))
d_loss += self.weight_rotation_loss_d * d_real_class_loss
d_loss.backward(retain_graph=True)
self.D_opt.step()
# Record loss
self.losses['D'].append(d_loss.data)
def _generator_train_iteration(self, generated_data, batch_size):
""" """
self.G_opt.zero_grad()
# Calculate loss and optimize
_, g_fake_pro_logits, g_fake_rot_logits, g_fake_rot_prob = self.D(generated_data)
g_loss = - torch.sum(g_fake_pro_logits)
        # add auxiliary rotation loss: same layout as in the critic update, four
        # consecutive blocks of `batch_size` samples labelled 0/90/180/270 degrees
        rot_labels = torch.arange(4).repeat_interleave(batch_size)
        if self.use_cuda:
            rot_labels = rot_labels.cuda()
        rot_labels = F.one_hot(rot_labels, 4).float()
g_fake_class_loss = torch.sum(F.binary_cross_entropy_with_logits(
input = g_fake_rot_logits,
target = rot_labels))
g_loss += self.weight_rotation_loss_g * g_fake_class_loss
g_loss.backward(retain_graph=True)
self.G_opt.step()
# Record loss
self.losses['G'].append(g_loss.data)
def _gradient_penalty(self, real_data, generated_data):
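        """Compute the WGAN-GP gradient penalty on points interpolated between
        real and generated samples."""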
batch_size = real_data.size()[0]
# Calculate interpolation
alpha = torch.rand(batch_size, 1, 1, 1)
alpha = alpha.expand_as(real_data)
if self.use_cuda:
alpha = alpha.cuda()
interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
interpolated = Variable(interpolated, requires_grad=True)
if self.use_cuda:
interpolated = interpolated.cuda()
# Calculate probability of interpolated examples
_, prob_interpolated, _, _ = self.D(interpolated)
# Calculate gradients of probabilities with respect to examples
        grad_outputs = torch.ones(prob_interpolated.size())
        if self.use_cuda:
            grad_outputs = grad_outputs.cuda()
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=grad_outputs,
                               create_graph=True, retain_graph=True)[0]
# Gradients have shape (batch_size, num_channels, img_width, img_height),
# so flatten to easily take norm per example in batch
gradients = gradients.view(batch_size, -1)
self.losses['gradient_norm'].append(gradients.norm(2, dim=1).sum().data)
# Derivatives of the gradient close to 0 can cause problems because of
# the square root, so manually calculate norm and add epsilon
gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
# Return gradient penalty
return self.gp_weight * ((gradients_norm - 1) ** 2).mean()
def _train_epoch(self, data_loader):
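        """Run one pass over `data_loader`: update the critic every step and the
        generator once every `critic_iterations` steps."""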
for i, data in enumerate(data_loader):
# Get generated data
data = data[0]
batch_size = data.size()[0]
generated_data = self.sample_generator(batch_size)
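            # Build the four "rotated" views (0, 90, 180, 270 degrees, realised
            # here with transpose/flip) of the generated batch for the rotation task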
x = generated_data
x_90 = x.transpose(2,3)
x_180 = x.flip(2,3)
x_270 = x.transpose(2,3).flip(2,3)
generated_data = torch.cat((x, x_90, x_180, x_270),0)
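            # Apply the same four transformations to the real batch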
x = data
x_90 = x.transpose(2,3)
x_180 = x.flip(2,3)
x_270 = x.transpose(2,3).flip(2,3)
data = torch.cat((x,x_90,x_180,x_270),0)
self.num_steps += 1
self._critic_train_iteration(data, generated_data, batch_size)
# Only update generator every |critic_iterations| iterations
if self.num_steps % self.critic_iterations == 0:
self._generator_train_iteration(generated_data, batch_size)
if i % self.print_every == 0:
print("Iteration {}".format(i + 1))
print("D: {}".format(self.losses['D'][-1]))
print("GP: {}".format(self.losses['GP'][-1]))
print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
if self.num_steps > self.critic_iterations:
print("G: {}".format(self.losses['G'][-1]))
def train(self, data_loader, epochs, save_training_gif=True):
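        """Train for `epochs` passes over `data_loader`; optionally save a GIF of
        samples generated from fixed latents after each epoch."""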
if save_training_gif:
# Fix latents to see how image generation improves during training
fixed_latents = Variable(self.G.sample_latent(64))
if self.use_cuda:
fixed_latents = fixed_latents.cuda()
training_progress_images = []
for epoch in range(epochs):
print("\nEpoch {}".format(epoch + 1))
self._train_epoch(data_loader)
if save_training_gif:
# Generate batch of images and convert to grid
img_grid = make_grid(self.G(fixed_latents).cpu().data)
# Convert to numpy and transpose axes to fit imageio convention
# i.e. (width, height, channels)
img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
# Add image grid to training progress
training_progress_images.append(img_grid)
if save_training_gif:
imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
training_progress_images)
def sample_generator(self, num_samples):
latent_samples = Variable(self.G.sample_latent(num_samples))
if self.use_cuda:
latent_samples = latent_samples.cuda()
generated_data = self.G(latent_samples)
return generated_data
def sample(self, num_samples):
generated_data = self.sample_generator(num_samples)
# Remove color channel
return generated_data.data.cpu().numpy()[:, 0, :, :]
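

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `Generator` and `Discriminator` below are
# hypothetical stand-ins for the model classes defined elsewhere in this
# repository: the discriminator is assumed to return a 4-tuple
# (features, gan_logits, rotation_logits, rotation_probs) and the generator to
# expose sample_latent(num_samples), as the Trainer above requires. The loss
# weights and optimizer settings are assumptions, not values from this repo.
#
#   import torch.optim as optim
#
#   G, D = Generator(), Discriminator()
#   trainer = Trainer(
#       G, D,
#       gen_optimizer=optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9)),
#       dis_optimizer=optim.Adam(D.parameters(), lr=4e-4, betas=(0.0, 0.9)),
#       weight_rotation_loss_d=1.0,
#       weight_rotation_loss_g=0.2,
#       use_cuda=torch.cuda.is_available(),
#   )
#   trainer.train(data_loader, epochs=100)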