Autoencoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from Layers import *
from Base import BaseModel, BaseAutoencoder
from Encoders import *
from Decoders import *


class Autoencoder(BaseAutoencoder):
    def __init__(self,
                 encoder,
                 decoder,
                 irmae=False,
                 vae=False,
                 vq_vae=False,
                 verbose=False,
                 **kwargs):
        super(Autoencoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.irmae = irmae
        self.vae = vae
        self.vq_vae = vq_vae
        self.verbose = verbose
        self.encoder.verbose = verbose
        self.decoder.verbose = verbose
        self._latent_size = self.encoder._latent_size
        assert self.vae == encoder.vae, \
            'Encoder vae flag must match the autoencoder vae flag'
        self.latent_dim = self.encoder.latent_dim
        if self.irmae:
            # IRMAE variant: pass latents through an extra MLP bottleneck.
            self.mlp = MLP(vae=self.vae, **kwargs['mlp_params'])
            assert self.vae == self.mlp.vae, \
                'MLP vae flag must match the autoencoder vae flag'
        if self.vq_vae:
            # VQ-VAE variant: quantize latents with an EMA-updated codebook.
            self.vq_embedding = VQEmbeddingEMA()
    def encode(self, x):
        z = self.encoder(x)
        return z

    def latent(self, z):
        if self.irmae:
            z = self.mlp(z)
        if self.vae:
            # Split the encoder output into mean and log-variance halves,
            # then sample z with the reparametrization trick.
            mu = z[:, :self.latent_dim]
            logvar = z[:, self.latent_dim:]
            z = self.reparametrization(mu, logvar)
            return z, mu, logvar
        elif self.vq_vae:
            z, loss, perplexity = self.vq_embedding(z)
            return z, loss, perplexity
        else:
            return z
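    # Note: reparametrization is inherited from BaseAutoencoder (not shown
    # in this file). Presumably it implements the standard trick, i.e.
    # something like: std = torch.exp(0.5 * logvar);
    # eps = torch.randn_like(std); return mu + eps * std.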
    def decode(self, z, *args):
        x = self.decoder(z, *args)
        return x

    def encode_to_latent(self, x):
        z = self.encode(x)
        # Some encoders return a (phase, z) tuple; keep only the latent here.
        if isinstance(z, tuple):
            phase, z = z
        else:
            phase = None
        latents = self.latent(z)
        if self.vae:
            z, mu, logvar = latents
        elif self.vq_vae:
            z, loss, perplexity = latents
        else:
            z = latents
        return z

    def forward(self, x):
        z = self.encode(x)
        if isinstance(z, tuple):
            phase, z = z
        else:
            phase = None
        latents = self.latent(z)
        if self.vae:
            z, mu, logvar = latents
        elif self.vq_vae:
            z, loss, perplexity = latents
        else:
            z = latents
        # Pass the phase through to the decoder when the encoder provides it.
        if phase is not None:
            x_hat = self.decode(z, phase)
        else:
            x_hat = self.decode(z)
        if self.vae:
            return x_hat, x, mu, logvar, z
        elif self.vq_vae:
            return x_hat, x, z, loss, perplexity
        else:
            return x_hat

    def sample(self, n_samples, current_device):
        # Draw latents from the standard-normal prior and decode them.
        z = torch.randn(n_samples, *self._latent_size)
        z = z.to(current_device)
        x_hat = self.decode(z)
        return x_hat  # return the decoded samples
    def generate(self, x):
        # Assumes forward returns a tuple (vae / vq_vae mode); [0] is x_hat.
        return self.forward(x)[0]
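

# --- Usage sketch ---
# A minimal, hypothetical example of wiring the Autoencoder in plain
# (non-VAE) mode. ToyEncoder/ToyDecoder are stand-ins invented here: the
# real modules come from Encoders/Decoders above and their constructor
# signatures are not shown, so this only illustrates the attributes the
# Autoencoder expects (vae, latent_dim, _latent_size, verbose).
if __name__ == '__main__':
    class ToyEncoder(nn.Module):
        def __init__(self, in_dim=784, latent_dim=16):
            super().__init__()
            self.vae = False
            self.latent_dim = latent_dim
            self._latent_size = (latent_dim,)
            self.verbose = False
            self.net = nn.Linear(in_dim, latent_dim)

        def forward(self, x):
            return self.net(x)

    class ToyDecoder(nn.Module):
        def __init__(self, latent_dim=16, out_dim=784):
            super().__init__()
            self.verbose = False
            self.net = nn.Linear(latent_dim, out_dim)

        def forward(self, z):
            return self.net(z)

    model = Autoencoder(encoder=ToyEncoder(), decoder=ToyDecoder())
    x = torch.randn(8, 784)
    x_hat = model(x)    # plain mode: forward returns the reconstruction only
    print(x_hat.shape)  # torch.Size([8, 784])
    samples = model.sample(4, 'cpu')
    print(samples.shape)  # torch.Size([4, 784])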