# bc.py
# Forked from aioz-ai/MICCAI19-MedVQA.
"""
This code is from Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang's repository.
https://github.com/jnhwkim/ban-vqa
"""
from __future__ import print_function
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
from fc import FCNet
class BCNet(nn.Module):
    """Simple class for non-linear bilinear connect network.

    Two inputs ``v`` and ``q`` are each passed through their own ``FCNet``
    (built with layer sizes ``[in_dim, h_dim * k]``) and then combined
    bilinearly. How the combination is computed depends on ``h_out``:

    * ``h_out is None``    -- return the raw pairwise element-wise products.
    * ``h_out <= self.c``  -- broadcast Hadamard product with a learned
      ``h_mat`` / ``h_bias`` (fast, memory hungry).
    * ``h_out > self.c``   -- batch outer product followed by a
      weight-normalised linear projection (slower, memory friendly).

    Shape comments below write ``d`` for the feature size produced by
    ``v_net`` / ``q_net`` (presumably ``h_dim * k``; ``FCNet`` is defined in
    another module -- confirm there). ``v`` is assumed to be
    ``b x v x v_dim`` and ``q`` to be ``b x q x q_dim``, as implied by the
    transposes and matmuls in ``forward``.
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act="ReLU", dropout=(0.2, 0.5), k=3):
        """Build the projection networks and the output head.

        Args:
            v_dim: feature size of the first input.
            q_dim: feature size of the second input.
            h_dim: hidden size; both projections output ``h_dim * k``.
            h_out: number of output maps, or ``None`` for no output head.
            act: activation name forwarded to ``FCNet``.
            dropout: pair ``(fc_dropout, attention_dropout)``; the first is
                forwarded to both ``FCNet`` instances, the second is applied
                to the projected ``v`` in ``forward``. Any indexable pair
                (tuple or list) is accepted.
            k: pooling window used by ``forward_with_weights``; the pooling
                layer is only created when ``k > 1``.
        """
        super(BCNet, self).__init__()

        # Threshold on h_out that selects the strategy used in forward().
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0])
        self.dropout = nn.Dropout(dropout[1])  # attention
        if k > 1:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out is None:
            # No output head: forward() returns the un-projected products.
            pass
        elif h_out <= self.c:
            self.h_mat = nn.Parameter(
                torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()
            )
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

    def forward(self, v, q):
        """Compute bilinear logits between every ``v`` / ``q`` position pair.

        Returns:
            ``b x v x q x d`` when ``h_out`` is ``None``; otherwise
            ``b x h_out x v x q``.
        """
        if self.h_out is None:
            # Note: attention dropout is not applied on this path.
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x d x v x q
            logits = d_.transpose(1, 2).transpose(2, 3)  # b x v x q x d
            return logits

        # broadcast Hadamard product, matrix-matrix product
        # fast computation but memory inefficient
        # epoch 1, time: 157.84
        if self.h_out <= self.c:
            v_ = self.dropout(self.v_net(v)).unsqueeze(1)
            q_ = self.q_net(q)
            h_ = v_ * self.h_mat  # broadcast, b x h_out x v x d
            logits = torch.matmul(
                h_, q_.unsqueeze(1).transpose(2, 3)
            )  # b x h_out x v x q
            logits = logits + self.h_bias
            return logits  # b x h_out x v x q

        # batch outer product, linear projection
        # memory efficient but slow computation
        # epoch 1, time: 304.87
        v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3)
        q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
        d_ = torch.matmul(v_, q_)  # b x d x v x q
        logits = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
        return logits.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q

    def forward_with_weights(self, v, q, w):
        """Reduce ``v`` and ``q`` to one vector per sample using weights ``w``.

        ``w`` must be matmul-compatible with the projections, i.e. its last
        two dimensions are ``v x q`` (presumably an attention map of shape
        ``b x v x q`` -- confirm against the caller).

        Returns:
            ``b x d`` when ``k == 1``; when ``k > 1`` the feature axis is
            sum-pooled in windows of ``k``, giving ``b x (d // k)``.
        """
        v_ = self.v_net(v).transpose(1, 2).unsqueeze(2)  # b x d x 1 x v
        q_ = self.q_net(q).transpose(1, 2).unsqueeze(3)  # b x d x q x 1
        # The products are computed in float32 and cast back to v_'s type.
        logits = torch.matmul(
            torch.matmul(v_.float(), w.unsqueeze(1).float()), q_.float()
        ).type_as(v_)  # b x d x 1 x 1
        logits = logits.squeeze(3).squeeze(2)  # b x d
        if self.k > 1:
            logits = logits.unsqueeze(1)  # b x 1 x d
            # Average pooling times k equals sum pooling over each window.
            logits = self.p_net(logits).squeeze(1) * self.k  # sum-pooling
        return logits
if __name__ == "__main__":
    # Smoke test: push one random batch through the linear-projection path
    # (h_out=1024 is above the BCNet.c threshold of 32).
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing. The intermediate tensors are large, so a CPU run needs
    # several GB of RAM.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BCNet(1024, 1024, 1024, 1024).to(device)
    # randn gives well-defined inputs; torch.Tensor(shape) would leave the
    # memory uninitialised (possibly NaN/Inf).
    x = torch.randn(512, 36, 1024, device=device)
    y = torch.randn(512, 14, 1024, device=device)
    # Call the module rather than .forward() so nn.Module hooks still run.
    out = net(x, y)