-
Notifications
You must be signed in to change notification settings - Fork 0
/
Matrix.py
120 lines (98 loc) · 3.88 KB
/
Matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
class CNN(nn.Module):
    """Map a feature map to a flattened, learned channel-covariance matrix.

    Three 3x3 convolutions (stride 1, padding 1, so spatial size is kept)
    reduce the input to ``matrixSize`` channels.  The Gram matrix of those
    channels, averaged over spatial positions, is flattened and passed
    through one fully connected layer.

    Args:
        matrixSize: Number of output channels of the conv stack; the result
            has ``matrixSize * matrixSize`` entries per sample.
        in_channels: Channel count of the input feature map.  Defaults to
            2048, the value that used to be hard-coded.
    """

    def __init__(self, matrixSize=32, in_channels=2048):
        super(CNN, self).__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels, 512, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, matrixSize, 3, 1, 1),
        )
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    def forward(self, x):
        """Return a ``(B, matrixSize * matrixSize)`` tensor for ``x``.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``.
        """
        out = self.convs(x)                       # (B, matrixSize, H, W)
        b, c, h, w = out.size()
        out = out.view(b, c, -1)                  # (B, matrixSize, H*W)
        # Channel Gram matrix, normalised by the number of spatial positions.
        out = torch.bmm(out, out.transpose(1, 2)).div(h * w)
        out = out.view(b, -1)                     # (B, matrixSize**2)
        return self.fc(out)
class VAE(nn.Module):
    """Small fully connected variational auto-encoder over 2048-d vectors.

    The encoder is one linear layer producing ``2 * z_dim`` values that are
    split into a mean and a log-scale half.  The decoder maps a ``z_dim``
    latent back to 2048 values.

    Args:
        z_dim: Size of the latent code.
    """

    def __init__(self, z_dim):
        super(VAE, self).__init__()
        self.encode = nn.Sequential(
            nn.Linear(2048, 2 * z_dim),
        )
        self.bn = nn.BatchNorm1d(z_dim)
        self.decode = nn.Sequential(
            nn.Linear(z_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = bn(mu) + eps * std`` with ``eps ~ N(0, I)``.

        ``logvar`` is exponentiated directly to obtain ``std``, so despite
        its name it is used as a log standard deviation.

        Args:
            mu: ``(B, z_dim)`` latent means (batch-normalised here).
            logvar: ``(B, z_dim)`` log-scale values.

        Returns:
            ``(B, z_dim)`` sampled latent code.
        """
        mu = self.bn(mu)
        std = torch.exp(logvar)
        eps = torch.randn_like(std)
        # The noise must scale the std; returning ``mu + std`` (as before)
        # ignored ``eps`` and made the "sample" deterministic.
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, decode, and compute the KL term.

        Args:
            x: 3-D tensor ``(B, C, H)`` whose last two dimensions flatten to
                2048 values (the encoder's input width).

        Returns:
            Tuple ``(out, KL)``: ``out`` has the same shape as ``x``; ``KL``
            is a scalar summed over batch and latent dimensions.
        """
        b, c, h = x.size()
        x = x.view(b, -1)
        z_q_mu, z_q_logvar = self.encode(x).chunk(2, dim=1)
        # reparameterize
        z_q = self.reparameterize(z_q_mu, z_q_logvar)
        out = self.decode(z_q)
        out = out.view(b, c, h)
        # KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(z_q_logvar):
        #   0.5 * (mu^2 + sigma^2 - 1) - log(sigma)
        # NOTE(review): this uses the raw encoder mean, while the sample
        # above uses the batch-normalised mean - confirm that is intended.
        KL = torch.sum(0.5 * (z_q_mu.pow(2) + z_q_logvar.exp().pow(2) - 1) - z_q_logvar)
        return out, KL
class MulLayer(nn.Module):
    """Linear feature-transformation layer driven by content and style features.

    Content features are mean-centred, compressed to ``matrixSize`` channels
    with a 1x1 convolution, multiplied by a learned transformation matrix
    (product of the style and content matrices predicted by two ``CNN``
    instances), expanded back to 2048 channels, and shifted by a style mean
    that has been passed through a ``VAE``.

    Args:
        z_dim: Latent size handed to the ``VAE`` applied to the style mean.
        matrixSize: Side length of the transformation matrix, and the
            channel count of the compressed content features.
    """

    def __init__(self, z_dim, matrixSize=32):
        super(MulLayer, self).__init__()
        self.snet = CNN(matrixSize)
        self.cnet = CNN(matrixSize)
        self.VAE = VAE(z_dim)
        self.matrixSize = matrixSize
        # Feature width is hard-coded to 2048 channels.  (A previous, disabled
        # variant keyed this on a layer name: 512 for 'r41', 256 for 'r31'.)
        self.compress = nn.Conv2d(2048, matrixSize, 1, 1, 0)
        self.unzip = nn.Conv2d(matrixSize, 2048, 1, 1, 0)
        self.transmatrix = None

    def forward(self, cF, sF, trans=True):
        """Transform content features ``cF`` using style features ``sF``.

        Args:
            cF: Content features, ``(B, 2048, H, W)``.
            sF: Style features, ``(B, 2048, H', W')``.
            trans: If true, apply the learned transformation; otherwise only
                compress and expand the centred content features.

        Returns:
            ``(out, transmatrix)`` when ``trans`` is true, where
            ``transmatrix`` is ``(B, matrixSize, matrixSize)``; just ``out``
            when ``trans`` is false.  ``out`` is ``(B, 2048, H, W)``.
            NOTE(review): the two branches return different arities.
        """
        # Centre the content features per channel.
        cb, cc = cF.size()[:2]
        cMean = cF.view(cb, cc, -1).mean(dim=2, keepdim=True)
        cMean = cMean.unsqueeze(3).expand_as(cF)
        cF = cF - cMean

        # Per-channel style mean, replaced by its VAE reconstruction.
        sb, sc = sF.size()[:2]
        sMean = sF.view(sb, sc, -1).mean(dim=2, keepdim=True)
        # NOTE(review): the KL term is computed but never returned or stored,
        # so callers cannot add it to a loss - confirm this is intended.
        sMean, KL = self.VAE(sMean)
        sMean = sMean.unsqueeze(3)
        sMeanC = sMean.expand_as(cF)
        sMeanS = sMean.expand_as(sF)
        sF = sF - sMeanS

        compress_content = self.compress(cF)
        b, c, h, w = compress_content.size()
        compress_content = compress_content.view(b, c, -1)

        if trans:
            cMatrix = self.cnet(cF)
            sMatrix = self.snet(sF)
            sMatrix = sMatrix.view(sMatrix.size(0), self.matrixSize, self.matrixSize)
            cMatrix = cMatrix.view(cMatrix.size(0), self.matrixSize, self.matrixSize)
            transmatrix = torch.bmm(sMatrix, cMatrix)
            transfeature = torch.bmm(transmatrix, compress_content).view(b, c, h, w)
            out = self.unzip(transfeature)
            # Re-colour with the (reconstructed) style mean.
            out = out + sMeanC
            return out, transmatrix

        out = self.unzip(compress_content.view(b, c, h, w))
        # Without a transformation, restore the original content mean.
        out = out + cMean
        return out