diff --git a/FastSurferCNN/models/sub_module.py b/FastSurferCNN/models/sub_module.py
index c1642b07..d9d8d64c 100644
--- a/FastSurferCNN/models/sub_module.py
+++ b/FastSurferCNN/models/sub_module.py
@@ -75,10 +75,7 @@ def forward(self, x):
         x2_gn = self.gn2(x1)
 
         # First Maxout
-        x1_gn = torch.unsqueeze(x1_gn, 4) # RF 3x3 --> weighted with attention map 1
-        x2_gn = torch.unsqueeze(x2_gn, 4) # RF 5x5 --> weighted with attention map 2
-        x2 = torch.cat((x2_gn, x1_gn), dim=4) # Concatenating along the 5th dimension
-        x2_max, _ = torch.max(x2, 4)
+        x2_max = torch.maximum(x2_gn, x1_gn)
         x2 = self.prelu(x2_max)
 
         # Convolution block 3 (RF: 7x7)
@@ -86,10 +83,8 @@ def forward(self, x):
         x3_gn = self.gn3(x2)
 
         # Second Maxout
-        x3_gn = torch.unsqueeze(x3_gn, 4) # RF 7x7 --> weighted with attention map 3
-        x2_max = torch.unsqueeze(x2_max, 4) # RF 3x3 and 5x5 from First Maxout (weighted with map 1 and 2)
-        x3 = torch.cat((x3_gn, x2_max), dim=4) # Concatenating along the 5th dimension
-        x3_max, _ = torch.max(x3, 4)
+
+        x3_max = torch.maximum(x3_gn, x2_max)
         x3 = self.prelu(x3_max)
 
         # Convolution block 4 (RF: 9x9)
@@ -182,10 +177,7 @@ def forward(self, x):
         x1_bn = self.bn1(x0)
 
         # First Maxout/Addition
-        x0_bn = torch.unsqueeze(x, 4) # Original input --> weighted with attention map 1
-        x1_bn = torch.unsqueeze(x1_bn, 4) # RF 3x3 --> weighted with attention map 2
-        x1 = torch.cat((x1_bn, x0_bn), dim=4) # Concatenate along the 5th dimension NB x C x H x W x F
-        x1_max, _ = torch.max(x1, 4)
+        x1_max = torch.maximum(x, x1_bn)
         x1 = self.prelu(x1_max)
 
         # Convolution block 2
@@ -193,10 +185,7 @@ def forward(self, x):
         x2_bn = self.bn2(x1)
 
         # Second Maxout/Addition
-        x2_bn = torch.unsqueeze(x2_bn, 4) # RF 5x5 --> weighted with attention map 3
-        x1_max = torch.unsqueeze(x1_max, 4) # Original and 3x3 weighted with attention map 1 and 2
-        x2 = torch.cat((x2_bn, x1_max), dim=4) # Concatenating along the 5th dimension
-        x2_max, _ = torch.max(x2, 4)
+        x2_max = torch.maximum(x2_bn, x1_max)
         x2 = self.prelu(x2_max)
 
         # Convolution block 3
@@ -204,10 +193,7 @@ def forward(self, x):
         x3_bn = self.bn3(x2)
 
         # Third Maxout/Addition
-        x3_bn = torch.unsqueeze(x3_bn, 4) # RF 7x7 --> weighted with attention map 4
-        x2_max = torch.unsqueeze(x2_max, 4) # orig, 3x3, 5x5 weighted with attention map 1-3
-        x3 = torch.cat((x3_bn, x2_max), dim=4) # Concatenating along the 5th dimension
-        x3_max, _ = torch.max(x3, 4)
+        x3_max = torch.maximum(x3_bn, x2_max)
         x3 = self.prelu(x3_max)
 
         # Convolution block 4 (end with batch-normed output to allow maxout across skip-connections)
@@ -301,10 +287,7 @@ def forward(self, x):
         x2_bn = self.bn2(x1)
 
         # First Maxout
-        x1_bn = torch.unsqueeze(x1_bn, 4) # RF 3x3
-        x2_bn = torch.unsqueeze(x2_bn, 4) # RF 5x5
-        x2 = torch.cat((x2_bn, x1_bn), dim=4) # Concatenating along the 5th dimension
-        x2_max, _ = torch.max(x2, 4)
+        x2_max = torch.maximum(x2_bn, x1_bn)
         x2 = self.prelu(x2_max)
 
         # Convolution block3 (RF: 7x7)
@@ -312,10 +295,7 @@ def forward(self, x):
         x3_bn = self.bn3(x2)
 
         # Second Maxout
-        x3_bn = torch.unsqueeze(x3_bn, 4) # RF 7x7
-        x2_max = torch.unsqueeze(x2_max, 4) # RF 3x3 and 5x5 from First Maxout (weighted with map 1 and 2)
-        x3 = torch.cat((x3_bn, x2_max), dim=4) # Concatenating along the 5th dimension
-        x3_max, _ = torch.max(x3, 4)
+        x3_max = torch.maximum(x3_bn, x2_max)
         x3 = self.prelu(x3_max)
 
         # Convolution block 4 (RF: 9x9)
@@ -426,11 +406,7 @@ def forward(self, x, out_block, indices):
         :return: processed feature maps
         """
         unpool = self.unpool(x, indices)
-        unpool = torch.unsqueeze(unpool, 4)
-
-        out_block = torch.unsqueeze(out_block, 4)
-        concat = torch.cat((unpool, out_block), dim=4) # Competitive Concatenation
-        concat_max, _ = torch.max(concat, 4)
+        concat_max = torch.maximum(unpool, out_block)
         out_block = super(CompetitiveDecoderBlock, self).forward(concat_max)
 
         return out_block
@@ -511,10 +487,7 @@ def forward(self, x, out_block):
         x2_gn = self.gn2(x1)
 
         # First Maxout
-        x1_gn = torch.unsqueeze(x1_gn, 4)
-        x2_gn = torch.unsqueeze(x2_gn, 4) # Add Singleton Dimension along 5th
-        x2 = torch.cat((x2_gn, x1_gn), dim=4) # Concatenating along the 5th dimension
-        x2_max, _ = torch.max(x2, 4)
+        x2_max = torch.maximum(x1_gn, x2_gn)
         x2 = self.prelu(x2_max)
 
         # Convolution block3; 7x7
@@ -522,10 +495,7 @@ def forward(self, x, out_block):
         x3_gn = self.gn3(x2)
 
         # Second Maxout
-        x3_gn = torch.unsqueeze(x3_gn, 4)
-        x2_max = torch.unsqueeze(x2_max, 4) # Add Singleton Dimension along 5th
-        x3 = torch.cat((x3_gn, x2_max), dim=4) # Concatenating along the 5th dimension
-        x3_max, _ = torch.max(x3, 4)
+        x3_max = torch.maximum(x3_gn, x2_max)
         x3 = self.prelu(x3_max)
 
         # Convolution block 4; 9x9
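
Note on the pattern (not part of the patch): every hunk replaces the same unsqueeze/cat/max idiom, which stacks two equally-shaped feature maps along a new 5th dimension and reduces over it, with a single element-wise torch.maximum call that avoids materializing the intermediate N x C x H x W x 2 tensor. A minimal standalone sketch of the equivalence, with illustrative (assumed) tensor shapes:

    import torch

    # Two same-shaped feature maps (N x C x H x W); shapes here are illustrative only.
    a = torch.randn(2, 16, 8, 8)
    b = torch.randn(2, 16, 8, 8)

    # Old idiom: add a singleton 5th dimension, concatenate, then reduce with max.
    old, _ = torch.max(torch.cat((a.unsqueeze(4), b.unsqueeze(4)), dim=4), 4)

    # New idiom: direct element-wise maximum, no intermediate concatenated tensor.
    new = torch.maximum(a, b)

    assert torch.equal(old, new)

torch.maximum was introduced in PyTorch 1.7, so the simplified code implicitly assumes at least that release.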