-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathfusion.py
110 lines (83 loc) · 4.26 KB
/
fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import division
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
class DirectAddFuse(HybridBlock):
    """Baseline fusion: plain element-wise sum of the two input feature maps.

    Serves as the no-attention reference point for the learned fusion
    modules below; it holds no parameters.
    """

    def __init__(self):
        super(DirectAddFuse, self).__init__()

    def hybrid_forward(self, F, x, residual):
        # No learned weighting — just add the identity branch back in.
        fused = x + residual
        return fused
class ResGlobLocaforGlobLocaChaFuse(HybridBlock):
    """Two-stage (iterative) channel-attention fusion of `x` and `residual`.

    Each stage computes a per-channel gate from the sum of a local branch
    (1x1-conv bottleneck over the full map) and a global branch (the same
    bottleneck after global average pooling), then blends the two inputs as
    ``x * w + residual * (1 - w)``.  The second stage re-estimates the gate
    from the stage-one blend.

    Parameters
    ----------
    channels : int
        Number of channels of both inputs (and of the output).
    r : int
        Bottleneck reduction ratio for the attention branches.
    """

    def __init__(self, channels=64, r=4):
        super(ResGlobLocaforGlobLocaChaFuse, self).__init__()
        inter_channels = int(channels // r)

        def _branch(prefix, pooled):
            # Channel-attention bottleneck: 1x1 conv down -> BN -> ReLU ->
            # 1x1 conv up -> BN.  `pooled` prepends a global average pool.
            seq = nn.HybridSequential(prefix=prefix)
            if pooled:
                seq.add(nn.GlobalAvgPool2D())
            seq.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0))
            seq.add(nn.BatchNorm())
            seq.add(nn.Activation('relu'))
            seq.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            seq.add(nn.BatchNorm())
            return seq

        # NOTE: creation order matches the original so the auto-generated
        # gluon parameter names stay identical.
        with self.name_scope():
            self.local_att = _branch('local_att', pooled=False)
            self.global_att = _branch('global_att', pooled=True)
            self.local_att2 = _branch('local_att2', pooled=False)
            self.global_att2 = _branch('global_att2', pooled=True)
            self.sig1 = nn.Activation('sigmoid')
            self.sig2 = nn.Activation('sigmoid')

    def hybrid_forward(self, F, x, residual):
        # Stage 1: gate from the naive sum, then blend the two inputs.
        initial_sum = x + residual
        gate1 = self.sig1(F.broadcast_add(self.local_att(initial_sum),
                                          self.global_att(initial_sum)))
        blended = F.broadcast_mul(x, gate1) + F.broadcast_mul(residual, 1 - gate1)

        # Stage 2: refine the gate using the stage-one blend.
        gate2 = self.sig2(F.broadcast_add(self.local_att2(blended),
                                          self.global_att2(blended)))
        output = F.broadcast_mul(x, gate2) + F.broadcast_mul(residual, 1 - gate2)
        return output
class ASKCFuse(HybridBlock):
    """Single-stage attentional fusion of `x` and `residual`.

    A per-channel gate is computed from the sum of a local branch
    (1x1-conv bottleneck) and a global branch (the same bottleneck after
    global average pooling); the output is
    ``2 * x * w + 2 * residual * (1 - w)``.

    Parameters
    ----------
    channels : int
        Number of channels of both inputs (and of the output).
    r : int
        Bottleneck reduction ratio for the attention branches.
    """

    def __init__(self, channels=64, r=4):
        super(ASKCFuse, self).__init__()
        inter_channels = int(channels // r)
        with self.name_scope():
            # Local branch: full-resolution 1x1-conv bottleneck.
            self.local_att = nn.HybridSequential(prefix='local_att')
            for layer in (nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0),
                          nn.BatchNorm(),
                          nn.Activation('relu'),
                          nn.Conv2D(channels, kernel_size=1, strides=1, padding=0),
                          nn.BatchNorm()):
                self.local_att.add(layer)

            # Global branch: same bottleneck on globally pooled features.
            self.global_att = nn.HybridSequential(prefix='global_att')
            for layer in (nn.GlobalAvgPool2D(),
                          nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0),
                          nn.BatchNorm(),
                          nn.Activation('relu'),
                          nn.Conv2D(channels, kernel_size=1, strides=1, padding=0),
                          nn.BatchNorm()):
                self.global_att.add(layer)

            self.sig = nn.HybridSequential(prefix='sig')
            self.sig.add(nn.Activation('sigmoid'))

    def hybrid_forward(self, F, x, residual):
        summed = x + residual
        gate = self.sig(F.broadcast_add(self.local_att(summed),
                                        self.global_att(summed)))
        # The factor of 2 rescales the convex blend; since w + (1 - w) == 1,
        # it keeps the output magnitude comparable to plain addition.
        output = 2 * F.broadcast_mul(x, gate) + 2 * F.broadcast_mul(residual, 1 - gate)
        return output