"""
paper: Tensor Fusion Network for Multimodal Sentiment Analysis
From: https://github.com/A2Zadeh/TensorFusionNetwork
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from models.subNets.FeatureNets import SubNet, TextSubNet


class TFN(nn.Module):
'''
    Implements the Tensor Fusion Network for multimodal sentiment analysis as described in:
Zadeh, Amir, et al. "Tensor fusion network for multimodal sentiment analysis." EMNLP 2017 Oral.
'''

    def __init__(self, args):
        '''
        Args:
            args.feature_dims - a length-3 tuple, contains (text_dim, audio_dim, video_dim)
            args.hidden_dims - another length-3 tuple in the same order, contains the hidden sizes of the three subnetworks
            args.text_out - int, the output dimension of the text subnetwork
            args.dropouts - a length-4 tuple, contains (audio_dropout, video_dropout, text_dropout, post_fusion_dropout)
            args.post_fusion_dim - int, the hidden size of the layers applied after tensor fusion
        Output:
            (return value in forward) a dict whose 'M' entry is the sentiment prediction, a value in (-3, 3)
        '''
super(TFN, self).__init__()
        # dimensions are specified in the order of text, audio and video
        self.text_in, self.audio_in, self.video_in = args.feature_dims
        self.text_hidden, self.audio_hidden, self.video_hidden = args.hidden_dims
        self.text_out = args.text_out
self.post_fusion_dim = args.post_fusion_dim
self.audio_prob, self.video_prob, self.text_prob, self.post_fusion_prob = args.dropouts
# define the pre-fusion subnetworks
self.audio_subnet = SubNet(self.audio_in, self.audio_hidden, self.audio_prob)
self.video_subnet = SubNet(self.video_in, self.video_hidden, self.video_prob)
self.text_subnet = TextSubNet(self.text_in, self.text_hidden, self.text_out, dropout=self.text_prob)
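        # SubNet / TextSubNet live in models/subNets/FeatureNets.py; following the paper, the audio and
        # video subnetworks are small feed-forward embedding nets and the text subnetwork is an LSTM encoder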
# define the post_fusion layers
self.post_fusion_dropout = nn.Dropout(p=self.post_fusion_prob)
self.post_fusion_layer_1 = nn.Linear((self.text_out + 1) * (self.video_hidden + 1) * (self.audio_hidden + 1), self.post_fusion_dim)
self.post_fusion_layer_2 = nn.Linear(self.post_fusion_dim, self.post_fusion_dim)
self.post_fusion_layer_3 = nn.Linear(self.post_fusion_dim, 1)
        # in TFN we are doing a regression with a constrained output range of (-3, 3), hence we apply a sigmoid
        # to the output to shrink it to (0, 1), and then scale/shift it back to the range (-3, 3)
self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
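        # i.e. output = 6 * sigmoid(x) - 3, so a raw sigmoid value of 0.5 maps to 0.0 (neutral sentiment)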

    def forward(self, text_x, audio_x, video_x):
        '''
        Args:
            text_x: tensor of shape (batch_size, sequence_len, text_in)
            audio_x: tensor of shape (batch_size, 1, audio_in); the singleton dim is squeezed out below
            video_x: tensor of shape (batch_size, 1, video_in); the singleton dim is squeezed out below
        '''
audio_x = audio_x.squeeze(1)
video_x = video_x.squeeze(1)
audio_h = self.audio_subnet(audio_x)
video_h = self.video_subnet(video_x)
text_h = self.text_subnet(text_x)
        batch_size = audio_h.shape[0]
        # next we perform "tensor fusion", which is essentially appending 1s to the tensors and taking the Kronecker product
add_one = torch.ones(size=[batch_size, 1], requires_grad=False).type_as(audio_h).to(text_x.device)
_audio_h = torch.cat((add_one, audio_h), dim=1)
_video_h = torch.cat((add_one, video_h), dim=1)
_text_h = torch.cat((add_one, text_h), dim=1)
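        # appending a constant 1 to each modality vector means the outer product below contains the
        # unimodal features and bimodal interactions as sub-tensors, in addition to the trimodal terms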
        # _audio_h has shape (batch_size, audio_hidden + 1), _video_h has shape (batch_size, video_hidden + 1)
        # we want to perform an outer product between the two batches, hence we unsqueeze them to get
        # (batch_size, audio_hidden + 1, 1) X (batch_size, 1, video_hidden + 1)
        # fusion_tensor will have shape (batch_size, audio_hidden + 1, video_hidden + 1)
fusion_tensor = torch.bmm(_audio_h.unsqueeze(2), _video_h.unsqueeze(1))
        # next we do the Kronecker product between fusion_tensor and _text_h. This is trickier:
        # we have to reshape the fusion tensor during the computation,
        # and in the end we don't keep the 3-D tensor; instead we flatten it
fusion_tensor = fusion_tensor.view(-1, (self.audio_hidden + 1) * (self.video_hidden + 1), 1)
fusion_tensor = torch.bmm(fusion_tensor, _text_h.unsqueeze(1)).view(batch_size, -1)
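        # after flattening, fusion_tensor has shape
        # (batch_size, (audio_hidden + 1) * (video_hidden + 1) * (text_out + 1)),
        # which matches the in_features of post_fusion_layer_1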
post_fusion_dropped = self.post_fusion_dropout(fusion_tensor)
post_fusion_y_1 = F.relu(self.post_fusion_layer_1(post_fusion_dropped), inplace=True)
post_fusion_y_2 = F.relu(self.post_fusion_layer_2(post_fusion_y_1), inplace=True)
post_fusion_y_3 = torch.sigmoid(self.post_fusion_layer_3(post_fusion_y_2))
output = post_fusion_y_3 * self.output_range + self.output_shift
res = {
'Feature_t': text_h,
'Feature_a': audio_h,
'Feature_v': video_h,
'Feature_f': fusion_tensor,
'M': output
}
return res
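

# Minimal usage sketch (not part of the original module): builds a TFN with a hypothetical
# `args` namespace and random inputs purely to illustrate the expected tensor shapes.
# All dimensions below are assumptions for illustration, not values from the MMSA configs.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_dims=(768, 5, 20),       # (text_in, audio_in, video_in) -- assumed
        hidden_dims=(128, 16, 16),       # (text_hidden, audio_hidden, video_hidden) -- assumed
        text_out=64,
        post_fusion_dim=128,
        dropouts=(0.2, 0.2, 0.2, 0.2),   # (audio, video, text, post_fusion)
    )
    model = TFN(args)
    text_x = torch.randn(8, 50, 768)     # (batch_size, sequence_len, text_in)
    audio_x = torch.randn(8, 1, 5)       # (batch_size, 1, audio_in)
    video_x = torch.randn(8, 1, 20)      # (batch_size, 1, video_in)
    res = model(text_x, audio_x, video_x)
    print(res['M'].shape)                # expected: torch.Size([8, 1])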