# drug_cell_attention.py
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Scaling factor sqrt(head_dim); the .double() suggests the surrounding
        # model is run in double precision.
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device).double()

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Split the hidden dimension into heads: [batch, n_heads, len, head_dim].
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention scores: [batch, n_heads, query_len, key_len].
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V)
        # Merge the heads back: [batch, query_len, hid_dim].
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention
class PositionwiseFeedforward(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(torch.relu(self.fc_1(x)))
        x = self.fc_2(x)
        return x
class DrugCellAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforward(hid_dim, pf_dim, dropout)
        self.fc = nn.Linear(2 * hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # Cross-attention over the encoder output and self-attention over trg.
        _trg_enc, _ = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        _trg_dec, _ = self.self_attention(trg, trg, trg, trg_mask)
        # Combine the two attention outputs by concatenation rather than summation.
        # _trg = _trg_enc + _trg_dec
        _trg = torch.cat((_trg_enc, _trg_dec), dim=2)
        _trg = torch.relu(self.fc(_trg))
        _trg = _trg.squeeze(1)
        # Residual connection + layer norm; trg: [batch, hid_dim].
        trg = self.layer_norm(trg + self.dropout(_trg))
        _trg = self.positionwise_feedforward(trg)
        trg = self.layer_norm(trg + self.dropout(_trg))
        return trg, None
class DrugCellAttention(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.device = device
        self.layers = nn.ModuleList([DrugCellAttentionLayer(hid_dim, n_heads, pf_dim, dropout, device)
                                     for _ in range(n_layers)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        # Stack of attention layers followed by a linear output head. Each layer
        # returns attention=None, so the second return value here is None as well.
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)
        return output, attention
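

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: it only sketches how the
    # classes above fit together. The shapes are assumptions inferred from the code:
    # trg is a per-sample drug embedding of shape [batch, hid_dim] and enc_src is a
    # cell-line feature sequence of shape [batch, src_len, hid_dim]. Inputs and model
    # are cast to double precision to match the .double() scale in MultiHeadAttention.
    device = torch.device("cpu")
    batch, src_len, hid_dim = 4, 10, 64
    model = DrugCellAttention(output_dim=1, hid_dim=hid_dim, n_layers=2, n_heads=8,
                              pf_dim=128, dropout=0.1, device=device).double().to(device)
    trg = torch.randn(batch, hid_dim, dtype=torch.float64, device=device)
    enc_src = torch.randn(batch, src_len, hid_dim, dtype=torch.float64, device=device)
    output, attention = model(trg, enc_src, trg_mask=None, src_mask=None)
    print(output.shape)  # expected: torch.Size([4, 1]); attention is None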