blocks.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttentionBlock(nn.Module):
    """Linear attention block: scores each spatial location of a local feature
    map l against a global feature map g, and returns the raw score map together
    with the attention-weighted feature vector."""

    def __init__(self, in_features, normalize_attn=True):
        super(LinearAttentionBlock, self).__init__()
        self.normalize_attn = normalize_attn
        # 1x1 convolution projecting the combined features to a single-channel score map
        self.op = nn.Conv2d(in_channels=in_features, out_channels=1,
                            kernel_size=1, padding=0, bias=False)

    def forward(self, l, g):
        N, C, H, W = l.size()
        # Compatibility scores between local and global features: N x 1 x H x W
        c = self.op(l + g)
        if self.normalize_attn:
            # Softmax over all spatial locations so the attention weights sum to 1
            a = F.softmax(c.view(N, 1, -1), dim=2).view(N, 1, H, W)
        else:
            a = torch.sigmoid(c)
        # Weight the local features by the attention map
        g = torch.mul(a.expand_as(l), l)
        if self.normalize_attn:
            # Normalized attention: summing over space gives the weighted average, N x C
            g = g.view(N, C, -1).sum(dim=2)
        else:
            # Unnormalized attention: average-pool the weighted features instead
            g = F.adaptive_avg_pool2d(g, (1, 1)).view(N, C)
        return c.view(N, 1, H, W), g
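
A minimal usage sketch follows; the shapes (batch of 4, 256 channels, 14x14 feature maps) are arbitrary assumptions for illustration, not values taken from the repository. It only assumes that l and g have already been brought to the same channel count and spatial size before being passed in.

# Hypothetical example with made-up shapes: l is a local feature map,
# g is a global feature map projected to match l.
import torch

block = LinearAttentionBlock(in_features=256, normalize_attn=True)
l = torch.randn(4, 256, 14, 14)   # local features
g = torch.randn(4, 256, 14, 14)   # global features, same shape as l
attn_map, weighted = block(l, g)
print(attn_map.shape)  # torch.Size([4, 1, 14, 14])
print(weighted.shape)  # torch.Size([4, 256])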