-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
73 lines (51 loc) · 1.44 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from torch import nn
import torch
class encoder(nn.Module):
    """VAE-style conditional encoder: image -> (mu, log_var).

    A stack of five stride-2 convolutions (3->32->64->128->256->512)
    downsamples the input, which is then flattened and projected to a
    512-d feature. That feature is concatenated with ``clip_enc``
    (assumed 512-d — the first MLP layer takes 1024 inputs; confirm
    against the caller) and passed through a 4-layer ReLU MLP before
    the two 512->256 heads producing mu and log-variance.
    """

    def __init__(self):
        super().__init__()

        conv_channels = [3, 32, 64, 128, 256, 512]
        conv_layers = []
        for c_in, c_out in zip(conv_channels, conv_channels[1:]):
            conv_layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            conv_layers.append(nn.BatchNorm2d(c_out))
            conv_layers.append(nn.LeakyReLU())
        # 8192 = 512 * 4 * 4, i.e. a 128x128 input after five /2 downsamples
        # — TODO confirm expected input resolution.
        conv_layers += [nn.Flatten(), nn.Linear(8192, 512)]
        self.convnet = nn.Sequential(*conv_layers)

        mlp_dims = [1024, 512, 512, 512, 512]
        mlp_layers = []
        for d_in, d_out in zip(mlp_dims, mlp_dims[1:]):
            mlp_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.mlp = nn.Sequential(*mlp_layers)

        # Separate heads for the Gaussian posterior parameters.
        self.to_mu = nn.Linear(512, 256)
        self.to_var = nn.Linear(512, 256)

    def forward(self, x, clip_enc):
        """Encode image ``x`` conditioned on ``clip_enc``.

        Returns a ``(mu, log_var)`` pair, each of shape (batch, 256).
        """
        feat = self.convnet(x)
        joint = torch.cat([feat, clip_enc], dim=-1)
        hidden = self.mlp(joint)
        return self.to_mu(hidden), self.to_var(hidden)
class decoder(nn.Module):
    """Conditional decoder head.

    Draws a latent via the reparameterization trick from ``(mu, log_var)``,
    concatenates the conditioning embedding ``clip_enc``, and maps the
    result through a 4-layer ReLU MLP (768 -> 512 -> 512 -> 512 -> 512).
    The 768 input implies a 256-d latent plus a 512-d conditioning vector.
    """

    def __init__(self):
        super().__init__()
        dims = [768, 512, 512, 512, 512]
        layers = []
        for d_in, d_out in zip(dims, dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.mlp = nn.Sequential(*layers)

    def forward(self, mu, log_var, clip_enc):
        """Sample z ~ N(mu, exp(log_var)), condition on ``clip_enc``, decode."""
        latent = self.sample(mu, log_var)
        conditioned = torch.cat([latent, clip_enc], dim=-1)
        return self.mlp(conditioned)

    def sample(self, mu, log_var):
        """Reparameterization trick: mu + eps * std with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma