-
Notifications
You must be signed in to change notification settings - Fork 203
/
latent_sde_lorenz.py
328 lines (273 loc) · 11.8 KB
/
latent_sde_lorenz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a latent SDE on data from a stochastic Lorenz attractor.
Reproduce the toy example in Section 7.2 of https://arxiv.org/pdf/2001.01328.pdf
To run this file, first run the following to install extra requirements:
pip install fire
To run, execute:
python -m examples.latent_sde_lorenz
"""
import logging
import os
from typing import Sequence
import fire
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from torch import nn
from torch import optim
from torch.distributions import Normal
import torchsde
class LinearScheduler:
    """Coefficient that grows linearly from ``maxval / iters`` up to ``maxval``.

    Each ``step()`` adds ``maxval / iters`` and the result is capped at ``maxval``.
    ``iters`` is clamped to at least 1, so ``iters <= 1`` yields ``maxval`` right away.
    """

    def __init__(self, iters, maxval=1.0):
        self._iters = max(1, iters)
        self._maxval = maxval
        # The schedule starts one increment above zero, not at zero.
        self._val = maxval / self._iters

    def step(self):
        """Advance the coefficient by one increment without exceeding ``maxval``."""
        increased = self._val + self._maxval / self._iters
        self._val = increased if increased < self._maxval else self._maxval

    @property
    def val(self):
        """Current value of the coefficient."""
        return self._val
class StochasticLorenz(object):
    """Stochastic Lorenz attractor.

    Used for simulating ground truth and obtaining noisy data.
    Details described in Section 7.2 https://arxiv.org/pdf/2001.01328.pdf
    Default a, b from https://openreview.net/pdf?id=HkzRQhR9YX

    Ito SDE with diagonal, multiplicative noise on the state ``(x1, x2, x3)``:
        dx1 = a1 (x2 - x1) dt         + b1 x1 dW1
        dx2 = (a2 x1 - x2 - x1 x3) dt + b2 x2 dW2
        dx3 = (x1 x2 - a3 x3) dt      + b3 x3 dW3
    """
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, a: Sequence = (10., 28., 8 / 3), b: Sequence = (.1, .28, .3)):
        super(StochasticLorenz, self).__init__()
        self.a = a  # Drift coefficients (a1, a2, a3).
        self.b = b  # Per-coordinate diffusion scales (b1, b2, b3).

    def f(self, t, y):
        """Drift term; ``y`` carries the three state coordinates along dim 1."""
        x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
        a1, a2, a3 = self.a

        f1 = a1 * (x2 - x1)
        f2 = a2 * x1 - x2 - x1 * x3
        f3 = x1 * x2 - a3 * x3
        return torch.cat([f1, f2, f3], dim=1)

    def g(self, t, y):
        """Diagonal diffusion term: each coordinate is scaled by its own ``b_i``."""
        x1, x2, x3 = torch.split(y, split_size_or_sections=(1, 1, 1), dim=1)
        b1, b2, b3 = self.b

        g1 = x1 * b1
        g2 = x2 * b2
        g3 = x3 * b3
        return torch.cat([g1, g2, g3], dim=1)

    @torch.no_grad()
    def sample(self, x0, ts, noise_std, normalize):
        """Simulate trajectories from ``x0`` at times ``ts`` and add Gaussian observation noise.

        Args:
            x0: initial states, with the three coordinates along dim 1.
            ts: times at which the solution is reported.
            noise_std: standard deviation of the i.i.d. Gaussian observation noise.
            normalize: if True, each coordinate is standardized with the mean/std taken over
                dims 0 and 1 of the solution before the noise is added (so the noise is in
                normalized units). The normalization constants are not stored.

        Returns:
            The noisy (optionally normalized) solution from ``torchsde.sdeint``; presumably
            shaped ``(len(ts), batch, 3)`` -- confirm against the torchsde docs.
        """
        xs = torchsde.sdeint(self, x0, ts)
        if normalize:
            mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
            xs.sub_(mean).div_(std)
        # Observation noise is applied regardless of `normalize`; `noise_std` must never be
        # silently ignored.
        xs.add_(torch.randn_like(xs) * noise_std)
        return xs
class Encoder(nn.Module):
    """GRU over a sequence followed by a linear map applied to every hidden state.

    With ``nn.GRU``'s default ``batch_first=False`` the input is time-major,
    ``(seq_len, batch, input_size)``. The output has one ``output_size`` vector per step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Creation order (gru, then lin) is kept so parameter initialization consumes the RNG
        # in the same order.
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size)
        self.lin = nn.Linear(hidden_size, output_size)

    def forward(self, inp):
        # nn.GRU returns (per-step outputs, final hidden state); only the former is used.
        per_step_hidden = self.gru(inp)[0]
        return self.lin(per_step_hidden)
class LatentSDE(nn.Module):
    """Latent SDE with a context-conditioned posterior drift, a prior drift and a shared diffusion.

    - Posterior drift ``f``: MLP on ``[z, ctx_t]``, where ``ctx_t`` is the encoder context picked
      for the current time (set through ``contextualize``).
    - Prior drift ``h``: MLP on ``z`` alone; used for unconditional sampling.
    - Diffusion ``g``: one small network per latent coordinate, applied element-wise (diagonal
      noise) and squashed by a sigmoid into (0, 1).
    Latent paths are mapped back to data space by the linear ``projector``.
    """
    sde_type = "ito"
    noise_type = "diagonal"

    def __init__(self, data_size, latent_size, context_size, hidden_size):
        super(LatentSDE, self).__init__()
        # Encoder.
        self.encoder = Encoder(input_size=data_size, hidden_size=hidden_size, output_size=context_size)
        # Outputs the mean and log-std of q(z0), concatenated along dim 1.
        self.qz0_net = nn.Linear(context_size, latent_size + latent_size)

        # Decoder.
        # Posterior drift: latent state concatenated with the context vector.
        self.f_net = nn.Sequential(
            nn.Linear(latent_size + context_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, latent_size),
        )
        # Prior drift: latent state only.
        self.h_net = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, latent_size),
        )
        # This needs to be an element-wise function for the SDE to satisfy diagonal noise.
        # One independent network per latent coordinate; the final Sigmoid keeps each entry in (0, 1).
        self.g_nets = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(1, hidden_size),
                    nn.Softplus(),
                    nn.Linear(hidden_size, 1),
                    nn.Sigmoid()
                )
                for _ in range(latent_size)
            ]
        )
        # Maps latent states to data space.
        self.projector = nn.Linear(latent_size, data_size)

        # Learnable Gaussian prior p(z0) = N(pz0_mean, exp(pz0_logstd)^2), broadcast over the batch.
        self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size))
        self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size))

        # (ts, ctx) read by the posterior drift `f`; populated by `contextualize`.
        self._ctx = None

    def contextualize(self, ctx):
        self._ctx = ctx  # A tuple of tensors of sizes (T,), (T, batch_size, d).

    def f(self, t, y):
        """Posterior drift: ``f_net`` on the latent state concatenated with the current context."""
        if self._ctx is None:
            raise RuntimeError("LatentSDE.f needs a context; call `contextualize` (or `forward`) first.")
        ts, ctx = self._ctx
        # With right=True, t in [ts[k], ts[k+1]) selects ctx[k+1]; clamp so t >= ts[-1] uses the last entry.
        i = min(torch.searchsorted(ts, t, right=True), len(ts) - 1)
        return self.f_net(torch.cat((y, ctx[i]), dim=1))

    def h(self, t, y):
        """Prior drift."""
        return self.h_net(y)

    def g(self, t, y):  # Diagonal diffusion.
        """Apply the i-th diffusion network to the i-th latent coordinate only."""
        y = torch.split(y, split_size_or_sections=1, dim=1)
        out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)]
        return torch.cat(out, dim=1)

    def forward(self, xs, ts, noise_std, adjoint=False, method="euler", dt=1e-2):
        """Compute the two ELBO terms for a batch of observed trajectories.

        Args:
            xs: observations, time-major ``(T, batch_size, data_size)``.
            ts: the ``T`` observation times (1-D and sorted, as required by ``torch.searchsorted`` in ``f``).
            noise_std: std of the Gaussian observation model ``N(projector(z_t), noise_std^2)``.
            adjoint: integrate with ``torchsde.sdeint_adjoint`` instead of ``torchsde.sdeint``.
            method: SDE solver name forwarded to torchsde.
            dt: solver step size.

        Returns:
            ``(log_pxs, kl)``. ``log_pxs`` is the observation log-likelihood summed over time and
            data dimensions and averaged over the batch. ``kl`` is ``KL(q(z0) || p(z0))`` plus the
            path log-ratio reported by torchsde, both averaged over the batch.
        """
        # Contextualization is only needed for posterior inference.
        # Running the recurrent encoder over the time-reversed sequence makes ctx[i] summarize xs[i:].
        ctx = self.encoder(torch.flip(xs, dims=(0,)))
        ctx = torch.flip(ctx, dims=(0,))
        self.contextualize((ts, ctx))

        # Reparameterized sample from q(z0), conditioned on the context at the first time step.
        qz0_mean, qz0_logstd = self.qz0_net(ctx[0]).chunk(chunks=2, dim=1)
        z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean)

        if adjoint:
            # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
            adjoint_params = (
                    (ctx,) +
                    tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters())
            )
            zs, log_ratio = torchsde.sdeint_adjoint(
                self, z0, ts, adjoint_params=adjoint_params, dt=dt, logqp=True, method=method)
        else:
            zs, log_ratio = torchsde.sdeint(self, z0, ts, dt=dt, logqp=True, method=method)

        _xs = self.projector(zs)
        xs_dist = Normal(loc=_xs, scale=noise_std)
        log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean(dim=0)

        qz0 = torch.distributions.Normal(loc=qz0_mean, scale=qz0_logstd.exp())
        pz0 = torch.distributions.Normal(loc=self.pz0_mean, scale=self.pz0_logstd.exp())
        logqp0 = torch.distributions.kl_divergence(qz0, pz0).sum(dim=1).mean(dim=0)
        logqp_path = log_ratio.sum(dim=0).mean(dim=0)
        return log_pxs, logqp0 + logqp_path

    @torch.no_grad()
    def sample(self, batch_size, ts, bm=None, dt=1e-3):
        """Draw unconditional trajectories from the prior and project them to data space.

        Args:
            batch_size: number of trajectories.
            ts: times at which the solution is reported.
            bm: optional Brownian motion to reuse (e.g. a fixed path for comparable plots).
            dt: solver step size.

        Returns:
            ``projector(zs)`` for prior latent paths ``zs``; no observation noise is added.
        """
        eps = torch.randn(size=(batch_size, *self.pz0_mean.shape[1:]), device=self.pz0_mean.device)
        z0 = self.pz0_mean + self.pz0_logstd.exp() * eps
        # `names` makes torchsde use the prior drift `h` in place of the posterior drift `f`.
        zs = torchsde.sdeint(self, z0, ts, names={'drift': 'h'}, dt=dt, bm=bm)
        # Most of the times in ML, we don't sample the observation noise for visualization purposes.
        _xs = self.projector(zs)
        return _xs
def make_dataset(t0, t1, batch_size, noise_std, train_dir, device):
    """Load the cached Lorenz dataset from ``train_dir``, or simulate and cache a new one.

    Args:
        t0, t1: start and end of the observation interval.
        batch_size: number of trajectories (dim 1 of ``xs``).
        noise_std: observation-noise std used when simulating. It is NOT checked against a
            cached file, so delete the cache after changing it.
        train_dir: directory holding ``lorenz_data.pth``; created if missing ('' means the
            current directory).
        device: device on which the returned tensors live (also for cached data).

    Returns:
        ``(xs, ts)``: normalized noisy trajectories and their 100 observation times.

    Raises:
        ValueError: if a cached file's batch size or time interval does not match the arguments.
    """
    data_path = os.path.join(train_dir, 'lorenz_data.pth')
    if os.path.exists(data_path):
        # map_location puts cached tensors on the requested device. Without it, data cached on
        # a GPU machine fails to load on a CPU-only one, and CPU-cached data would not follow
        # the model onto CUDA.
        data_dict = torch.load(data_path, map_location=device)
        xs, ts = data_dict['xs'], data_dict['ts']
        logging.warning(f'Loaded toy data at: {data_path}')
        if xs.shape[1] != batch_size:
            raise ValueError("Batch size has changed; please delete and regenerate the data.")
        if ts[0] != t0 or ts[-1] != t1:
            raise ValueError("Times interval [t0, t1] has changed; please delete and regenerate the data.")
    else:
        _y0 = torch.randn(batch_size, 3, device=device)
        ts = torch.linspace(t0, t1, steps=100, device=device)
        xs = StochasticLorenz().sample(_y0, ts, noise_std, normalize=True)

        data_dir = os.path.dirname(data_path)
        if data_dir:  # os.makedirs('') raises, so skip it when train_dir is empty.
            os.makedirs(data_dir, exist_ok=True)
        torch.save({'xs': xs, 'ts': ts}, data_path)
        logging.warning(f'Stored toy data at: {data_path}')
    return xs, ts
def _plot_trajectories(ax, trajectories, num_samples, title):
    """Draw the first ``num_samples`` 3-D trajectories on ``ax`` and mark their starting points."""
    # `trajectories` is indexed [time, sample, coordinate]; split the three coordinates apart.
    z1, z2, z3 = np.split(trajectories, indices_or_sections=3, axis=-1)
    for i in range(num_samples):
        ax.plot(z1[:, i, 0], z2[:, i, 0], z3[:, i, 0])
    ax.scatter(z1[0, :num_samples, 0], z2[0, :num_samples, 0], z3[0, :num_samples, 0], marker='x')
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_zticklabels([])
    ax.set_xlabel('$z_1$', labelpad=0., fontsize=16)
    ax.set_ylabel('$z_2$', labelpad=.5, fontsize=16)
    ax.set_zlabel('$z_3$', labelpad=0., horizontalalignment='center', fontsize=16)
    ax.set_title(title, fontsize=20)


def vis(xs, ts, latent_sde, bm_vis, img_path, num_samples=10):
    """Save a side-by-side 3-D plot of data trajectories and model samples to ``img_path``.

    Args:
        xs: observed trajectories, indexed [time, batch, coordinate] with 3 coordinates.
        ts: observation times, forwarded to ``latent_sde.sample``.
        latent_sde: model exposing ``sample(batch_size=..., ts=..., bm=...)``.
        bm_vis: Brownian motion passed to ``sample`` so plots are comparable across calls.
        img_path: output file path.
        num_samples: number of trajectories drawn per panel (capped at the batch size).
    """
    fig = plt.figure(figsize=(20, 9))
    gs = gridspec.GridSpec(1, 2)
    ax00 = fig.add_subplot(gs[0, 0], projection='3d')
    ax01 = fig.add_subplot(gs[0, 1], projection='3d')

    # Never index past the available trajectories.
    num_samples = min(num_samples, xs.size(1))

    # Left plot: data.
    _plot_trajectories(ax00, xs.cpu().numpy(), num_samples, 'Data')
    xlim, ylim, zlim = ax00.get_xlim(), ax00.get_ylim(), ax00.get_zlim()

    # Right plot: samples from learned model, shown on the data panel's axis limits.
    samples = latent_sde.sample(batch_size=xs.size(1), ts=ts, bm=bm_vis).cpu().numpy()
    _plot_trajectories(ax01, samples, num_samples, 'Samples')
    ax01.set_xlim(xlim)
    ax01.set_ylim(ylim)
    ax01.set_zlim(zlim)

    fig.savefig(img_path)
    plt.close(fig)
def main(
        batch_size=1024,
        latent_size=4,
        context_size=64,
        hidden_size=128,
        lr_init=1e-2,
        t0=0.,
        t1=2.,
        lr_gamma=0.997,
        num_iters=5000,
        kl_anneal_iters=1000,
        pause_every=50,
        noise_std=0.01,
        adjoint=False,
        train_dir='./dump/lorenz/',
        method="euler",
):
    """Fit a LatentSDE to (cached or freshly simulated) stochastic-Lorenz data.

    The loss is ``-log_pxs + kl_weight * log_ratio``. ``kl_weight`` comes from a LinearScheduler
    over ``kl_anneal_iters`` iterations, and the learning rate decays by ``lr_gamma`` per iteration.
    Every ``pause_every`` iterations the loss terms are logged and a data-vs-samples figure is
    written into ``train_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    xs, ts = make_dataset(t0=t0, t1=t1, batch_size=batch_size, noise_std=noise_std, train_dir=train_dir, device=device)

    model = LatentSDE(
        data_size=3,
        latent_size=latent_size,
        context_size=context_size,
        hidden_size=hidden_size,
    ).to(device)
    optimizer = optim.Adam(params=model.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_gamma)
    kl_scheduler = LinearScheduler(iters=kl_anneal_iters)

    # One fixed Brownian path, reused by every visualization so figures are comparable over time.
    bm_vis = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=(batch_size, latent_size,), device=device, levy_area_approximation="space-time")

    for global_step in tqdm.tqdm(range(1, num_iters + 1)):
        model.zero_grad()
        log_pxs, log_ratio = model(xs, ts, noise_std, adjoint, method)
        kl_weight = kl_scheduler.val
        loss = -log_pxs + log_ratio * kl_weight
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        if global_step % pause_every != 0:
            continue

        lr_now = optimizer.param_groups[0]['lr']
        logging.warning(
            f'global_step: {global_step:06d}, lr: {lr_now:.5f}, '
            f'log_pxs: {log_pxs:.4f}, log_ratio: {log_ratio:.4f} loss: {loss:.4f}, kl_coeff: {kl_scheduler.val:.4f}'
        )
        img_path = os.path.join(train_dir, f'global_step_{global_step:06d}.pdf')
        vis(xs, ts, model, bm_vis, img_path)
if __name__ == "__main__":
    # python-fire exposes main()'s keyword arguments as command-line flags.
    fire.Fire(component=main)