-
-
Notifications
You must be signed in to change notification settings - Fork 988
/
vae.py
256 lines (224 loc) · 8.99 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import numpy as np
import torch
import torch.nn as nn
import visdom
from utils.mnist_cached import MNISTCached as MNIST
from utils.mnist_cached import setup_data_loaders
from utils.vae_plots import mnist_test_tsne, plot_llk, plot_vae_samples
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
# define the PyTorch module that parameterizes the
# diagonal gaussian distribution q(z|x)
class Encoder(nn.Module):
    """Inference network producing the parameters of a diagonal Gaussian q(z|x).

    Args:
        z_dim: size of the latent code.
        hidden_dim: number of units in the single hidden layer.

    Inputs are flattened to 784 features per example before the first layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # shared hidden layer followed by two heads: fc21 -> location,
        # fc22 -> log of the scale (exponentiated in forward)
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each of shape ``(batch_size, z_dim)``."""
        # put the pixels in the rightmost dimension: (batch_size, 784)
        flat = x.reshape(-1, 784)
        features = self.softplus(self.fc1(flat))
        # exponentiating guarantees a strictly positive scale
        return self.fc21(features), self.fc22(features).exp()
# define the PyTorch module that parameterizes the
# observation likelihood p(x|z)
class Decoder(nn.Module):
    """Generative network mapping a latent code to per-pixel Bernoulli probabilities.

    Args:
        z_dim: size of the latent code.
        hidden_dim: number of units in the single hidden layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # latent -> hidden -> 784 output logits
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        self.softplus = nn.Softplus()

    def forward(self, z):
        """Return probabilities of shape ``(batch_size, 784)``, each in [0, 1]."""
        features = self.softplus(self.fc1(z))
        # the sigmoid squashes logits into valid Bernoulli probabilities
        return self.fc21(features).sigmoid()
# define a PyTorch module for the VAE
class VAE(nn.Module):
    """Variational autoencoder: Pyro model p(x|z)p(z) plus guide q(z|x).

    Args:
        z_dim: dimensionality of the latent space (default 50).
        hidden_dim: hidden units used by both encoder and decoder (default 400).
        use_cuda: if True, all encoder/decoder parameters are moved to the GPU
            at construction time.
    """

    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Define the generative model p(x|z)p(z) for a mini-batch ``x``.

        ``x.shape[0]`` is treated as the batch size; each example is flattened
        to 784 values for the observation site. Returns the decoder's
        per-pixel probabilities so they can be visualized.
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for the standard-normal prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z; call the module (not .forward) so that
            # nn.Module hooks run, consistent with reconstruct_img
            loc_img = self.decoder(z)
            # score against actual images (with relaxed Bernoulli values);
            # validate_args=False because observed pixels need not be exactly 0/1
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                obs=x.reshape(-1, 784),
            )
            # return the loc so we can visualize it later
            return loc_img

    def guide(self, x):
        """Define the variational distribution q(z|x) for a mini-batch ``x``."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def reconstruct_img(self, x):
        """Encode ``x``, draw one latent sample, and return decoded pixel probabilities.

        No sampling happens in image space; the decoder's probabilities are
        returned directly.
        """
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img
def _train_epoch(svi, train_loader, use_cuda):
    """Take one SVI step per mini-batch and return the summed (unnormalized) loss."""
    epoch_loss = 0.0
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)
    return epoch_loss


def _show_reconstructions(vae, vis, x, num_images=3):
    """Send ``num_images`` random images from batch ``x`` and their reconstructions to visdom."""
    reco_indices = np.random.randint(0, x.shape[0], num_images)
    # reconstructions are only displayed, so don't build an autograd graph
    with torch.no_grad():
        for index in reco_indices:
            test_img = x[index, :]
            reco_img = vae.reconstruct_img(test_img)
            vis.image(
                test_img.reshape(28, 28).detach().cpu().numpy(),
                opts={"caption": "test image"},
            )
            vis.image(
                reco_img.reshape(28, 28).detach().cpu().numpy(),
                opts={"caption": "reconstructed image"},
            )


def _test_epoch(svi, vae, test_loader, use_cuda, vis=None):
    """Accumulate ``svi.evaluate_loss`` over the test set (no gradient steps).

    When ``vis`` is given, the first mini-batch is also used for visdom
    sample/reconstruction plots. Returns the summed (unnormalized) loss.
    """
    test_loss = 0.0
    for i, (x, _) in enumerate(test_loader):
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)
        # visualize how well we're reconstructing images from the first mini-batch
        if i == 0 and vis is not None:
            plot_vae_samples(vae, vis)
            _show_reconstructions(vae, vis, x)
    return test_loss


def main(args):
    """Train a VAE on MNIST with SVI and return the trained module.

    Uses ``args.cuda``, ``args.learning_rate``, ``args.jit``,
    ``args.visdom_flag``, ``args.num_epochs``, ``args.test_frequency`` and
    ``args.tsne_iter``. Reported losses are divided by the dataset size.
    """
    # clear param store
    pyro.clear_param_store()

    # setup MNIST data loaders
    train_loader, test_loader = setup_data_loaders(
        MNIST, use_cuda=args.cuda, batch_size=256
    )

    # setup the VAE, optimizer and inference algorithm
    vae = VAE(use_cuda=args.cuda)
    optimizer = Adam({"lr": args.learning_rate})
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # setup visdom for visualization (None disables all visdom plotting)
    vis = visdom.Visdom() if args.visdom_flag else None

    # per-epoch average losses, keyed by epoch index
    train_elbo = {}
    test_elbo = {}
    for epoch in range(args.num_epochs):
        # report training diagnostics
        epoch_loss = _train_epoch(svi, train_loader, args.cuda)
        total_epoch_loss_train = epoch_loss / len(train_loader.dataset)
        train_elbo[epoch] = total_epoch_loss_train
        print(
            "[epoch %03d] average training loss: %.4f"
            % (epoch, total_epoch_loss_train)
        )

        if epoch % args.test_frequency == 0:
            # report test diagnostics
            test_loss = _test_epoch(svi, vae, test_loader, args.cuda, vis)
            total_epoch_loss_test = test_loss / len(test_loader.dataset)
            test_elbo[epoch] = total_epoch_loss_test
            print(
                "[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)
            )
            plot_llk(train_elbo, test_elbo)

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)

    return vae
if __name__ == "__main__":
    # this example is pinned to one specific Pyro release
    assert pyro.__version__.startswith("1.9.1")

    # build the command-line interface
    parser = argparse.ArgumentParser(description="parse args")
    add_option = parser.add_argument
    add_option(
        "-n", "--num-epochs", type=int, default=101, help="number of training epochs"
    )
    add_option(
        "-tf",
        "--test-frequency",
        type=int,
        default=5,
        help="how often we evaluate the test set",
    )
    add_option(
        "-lr", "--learning-rate", type=float, default=1.0e-3, help="learning rate"
    )
    add_option("--cuda", action="store_true", default=False, help="whether to use cuda")
    add_option(
        "--jit", action="store_true", default=False, help="whether to use PyTorch jit"
    )
    add_option(
        "-visdom",
        "--visdom_flag",
        action="store_true",
        help="Whether plotting in visdom is desired",
    )
    add_option(
        "-i-tsne",
        "--tsne_iter",
        type=int,
        default=100,
        help="epoch when tsne visualization runs",
    )

    # parse command line arguments and train
    args = parser.parse_args()
    model = main(args)