vae.py
from functools import partial
import itertools
from typing import Optional, Any
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.util import safe_map
from flax import linen as nn
from einops import repeat, rearrange
from vae_helpers import (Conv1x1, Conv3x3, gaussian_sample, resize, parse_layer_string, pad_channels, get_width_settings,
gaussian_kl, Attention, recon_loss, sample, normalize, checkpoint, lecun_normal, has_attn, Block, EncBlock, identity)
import hps
map = safe_map
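# Hierarchical VAE (VDVAE-style) in Flax/JAX. The Encoder returns a dict of
# activations keyed by spatial resolution; the Decoder consumes them top-down
# through a stack of DecBlocks and maps the result to image channels.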
class Encoder(nn.Module):
    H: hps.Hyperparams

    @nn.compact
    def __call__(self, x):
        H = self.H
        widths = get_width_settings(H.custom_width_str)
        assert widths[str(int(x.shape[1]))] == H.width
        x = Conv3x3(H.width, dtype=H.dtype)(x)
        blocks = parse_layer_string(H.enc_blocks)
        n_blocks = len(blocks)
        activations = {}
        activations[x.shape[1]] = x  # keyed by spatial dimension
        for res, down_rate in blocks:
            use_3x3 = res > 2  # Don't use 3x3s for 1x1, 2x2 patches
            width = widths.get(str(res), H.width)
            block = EncBlock(H, res, width, down_rate or 1, use_3x3, last_scale=np.sqrt(1 / n_blocks))
            x = checkpoint(block.__call__, H, (x,))
            new_res = x.shape[1]
            new_width = widths.get(str(new_res), H.width)
            x = x if (x.shape[3] == new_width) else pad_channels(x, new_width)
            activations[new_res] = x
        return activations
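# One top-down decoder block: the approximate posterior is computed from the
# current decoder state and the matching encoder activations, the prior from the
# state alone; the sampled z is projected and mixed back into the deterministic
# path through a residual block.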
class DecBlock(nn.Module):
    H: hps.Hyperparams
    res: int
    mixin: Optional[int]
    n_blocks: int
    bias_cond: bool

    def setup(self):
        H = self.H
        width = self.width = get_width_settings(
            H.custom_width_str).get(str(self.res), H.width)
        use_3x3 = self.res > 2
        cond_width = int(width * H.bottleneck_multiple)
        self.zdim = H.zdim if H.zdim != -1 else width
        self.enc = Block(H, cond_width, self.zdim * 2,
                         use_3x3=use_3x3)
        self.prior = Block(H, cond_width, self.zdim * 2 + width,
                           use_3x3=use_3x3, last_scale=0.)
        self.z_proj = Conv1x1(
            width, kernel_init=lecun_normal(np.sqrt(1 / self.n_blocks)),
            dtype=self.H.dtype)
        self.resnet = Block(H, cond_width, width, residual=True,
                            use_3x3=use_3x3,
                            last_scale=np.sqrt(1 / self.n_blocks))
        self.z_fn = lambda x: self.z_proj(x)
        self.pre_layer = Attention(H) if has_attn(self.res, H) else identity
        self.bias_x = None
        self.expand_cond = self.mixin is not None
        if self.bias_cond:
            self.bias_x = self.param('bias_' + str(self.res), lambda key, shape: jnp.zeros(shape, dtype=H.dtype), (self.res, self.res, width))

    def sample(self, x, acts, rng):
        x = jnp.broadcast_to(x, acts.shape)
        qm, qv = jnp.split(self.enc(jnp.concatenate([x, acts], 3)), 2, 3)
        pm, pv, xpp = jnp.split(self.prior(x), [self.zdim, 2 * self.zdim], 3)
        z = gaussian_sample(qm, jnp.exp(qv), rng)
        kl = gaussian_kl(qm, pm, qv, pv)
        return z, x + xpp, kl

    def sample_uncond(self, x, rng, t=None, lvs=None):
        pm, pv, xpp = jnp.split(self.prior(x), [self.zdim, 2 * self.zdim], 3)
        z = gaussian_sample(pm, jnp.exp(pv) * (t or 1), rng) if lvs is None else lvs
        return z, x + xpp

    def add_bias(self, x, batch):
        bias = repeat(self.bias_x, '... -> b ...', b=batch)
        if x is None:
            return bias
        else:
            return x + bias

    def forward(self, acts, rng, x=None):
        if self.expand_cond:
            # Assume width increases monotonically with depth
            x = resize(x[..., :acts.shape[3]], (self.res, self.res))
        if self.bias_cond:
            x = self.add_bias(x, acts.shape[0])
        x = self.pre_layer(x)
        z, x, kl = self.sample(x, acts, rng)
        x = self.resnet(x + self.z_fn(z))
        return z, x, kl

    def __call__(self, acts, rng, x=None):
        z, x, kl = self.forward(acts, rng, x=x)
        return x, dict(kl=kl)

    def get_latents(self, acts, rng, x=None):
        z, x, kl = self.forward(acts, rng, x=x)
        return x, dict(kl=kl, z=z)

    def forward_uncond(self, rng, n, t=None, lvs=None, x=None):
        if self.expand_cond:
            # Assume width increases monotonically with depth
            x = resize(x[..., :self.width], (self.res, self.res))
        if self.bias_cond:
            x = self.add_bias(x, n)
        x = self.pre_layer(x)
        z, x = self.sample_uncond(x, rng, t, lvs)
        return self.resnet(x + self.z_fn(z))
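# Top-down decoder: stacks DecBlocks as specified by H.dec_blocks, then applies a
# learned gain/bias and a 1x1 convolution to produce the output at H.n_channels
# channels.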
class Decoder(nn.Module):
    H: hps.Hyperparams

    def setup(self):
        H = self.H
        resos = set()
        dec_blocks = []
        self.widths = get_width_settings(H.custom_width_str)
        blocks = parse_layer_string(H.dec_blocks)
        for idx, (res, mixin) in enumerate(blocks):
            bias_cond = (mixin is not None and res <= H.no_bias_above) or (idx == 0)
            dec_blocks.append(DecBlock(H, res, mixin, n_blocks=len(blocks), bias_cond=bias_cond))
            resos.add(res)
        self.dec_blocks = dec_blocks
        self.gain = self.param('gain', lambda key, shape: jnp.ones(shape, dtype=self.H.dtype), (H.width,))
        self.bias = self.param('bias', lambda key, shape: jnp.zeros(shape, dtype=self.H.dtype), (H.width,))
        self.out_conv = Conv1x1(H.n_channels, dtype=self.H.dtype)
        self.final_fn = lambda x: self.out_conv(x * self.gain + self.bias)

    def __call__(self, activations, rng, get_latents=False):
        stats = []
        for idx, block in enumerate(self.dec_blocks):
            rng, block_rng = random.split(rng)
            acts = activations[block.res]
            f = block.__call__ if not get_latents else block.get_latents
            if idx == 0:
                x, block_stats = checkpoint(f, self.H, (acts, block_rng))
            else:
                x, block_stats = checkpoint(f, self.H, (acts, block_rng, x))
            stats.append(block_stats)
        return self.final_fn(x), stats

    def forward_uncond(self, n, rng, t=None):
        x = None
        for idx, block in enumerate(self.dec_blocks):
            t_block = t[idx] if isinstance(t, list) else t
            rng, block_rng = random.split(rng)
            x = block.forward_uncond(block_rng, n, t_block, x=x)
        return self.final_fn(x)

    def forward_manual_latents(self, n, latents, rng, t=None):
        x = None
        for idx, (block, lvs) in enumerate(itertools.zip_longest(self.dec_blocks, latents)):
            rng, block_rng = random.split(rng)
            x = block.forward_uncond(block_rng, n, t, lvs, x=x)
        return self.final_fn(x)
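# Full model: encoder plus decoder. __call__ returns the training losses
# (reconstruction loss and per-dimension KL summed over decoder blocks); the
# remaining methods expose latent extraction and unconditional / manual-latent
# sampling.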
class VDVAE(nn.Module):
    H: hps.Hyperparams

    def setup(self):
        self.encoder = Encoder(self.H)
        self.decoder = Decoder(self.H)

    def __call__(self, x, rng=None, **kwargs):
        x_target = jnp.array(x)  # copy of the input used as the reconstruction target
        rng, uncond_rng = random.split(rng)
        px_z, stats = self.decoder(self.encoder(x), rng)
        ndims = np.prod(x.shape[1:])
        kl = sum((s['kl'] / ndims).sum((1, 2, 3)).mean() for s in stats)
        loss = recon_loss(px_z, x_target)
        return dict(loss=loss + kl, recon_loss=loss, kl=kl), None

    def forward_get_latents(self, x, rng):
        return self.decoder(self.encoder(x), rng, get_latents=True)[-1]

    def forward_uncond_samples(self, size, rng, t=None):
        return sample(self.decoder.forward_uncond(size, rng, t=t))

    def forward_samples_set_latents(self, size, latents, rng, t=None):
        return sample(self.decoder.forward_manual_latents(size, latents, rng, t=t))
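# A minimal usage sketch, assuming the standard Flax init/apply workflow and that
# an hps.Hyperparams instance (H) is built elsewhere, e.g. from the defaults in
# hps.py; the input shape below is illustrative only.
#
#   H = ...                                          # hps.Hyperparams with enc_blocks/dec_blocks etc.
#   model = VDVAE(H)
#   x = jnp.zeros((2, 32, 32, 3), dtype=H.dtype)     # NHWC batch; resolution must match enc_blocks
#   init_rng, model_rng = random.split(random.PRNGKey(0))
#   variables = model.init(init_rng, x, rng=model_rng)
#   metrics, _ = model.apply(variables, x, rng=model_rng)      # dict with loss, recon_loss, kl
#   images = model.apply(variables, 4, model_rng,
#                        method=VDVAE.forward_uncond_samples)  # 4 unconditional samples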