Skip to content

Commit

Permalink
ACON Activation batch-size 1 bug patch (ultralytics#2901)
Browse files Browse the repository at this point in the history
* ACON Activation batch-size 1 bug path

This is not a great solution to nmaac/acon#4 but it's all I could think of at the moment.

WARNING: YOLOv5 models with MetaAconC() activations are incapable of running inference at batch-size 1 properly due to a known bug in nmaac/acon#4 with no known solution.

* Update activations.py

* Update activations.py

* Update activations.py

* Update activations.py
  • Loading branch information
glenn-jocher authored Apr 25, 2021
1 parent c0d3f80 commit 9c7bb5a
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
c2 = max(r, c1 // r)
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
self.bn1 = nn.BatchNorm2d(c2)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
self.bn2 = nn.BatchNorm2d(c1)
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
# self.bn1 = nn.BatchNorm2d(c2)
# self.bn2 = nn.BatchNorm2d(c1)

def forward(self, x):
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
# batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
# beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) # bug/unstable
beta = torch.sigmoid(self.fc2(self.fc1(y))) # bug patch BN layers removed
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 comments on commit 9c7bb5a

Please sign in to comment.