Simplified competitive concatenation with torch.maximum
LeHenschel committed May 8, 2023
1 parent 99508f4 commit ed31585
Showing 1 changed file with 11 additions and 41 deletions.
52 changes: 11 additions & 41 deletions FastSurferCNN/models/sub_module.py
@@ -75,21 +75,16 @@ def forward(self, x):
x2_gn = self.gn2(x1)

# First Maxout
- x1_gn = torch.unsqueeze(x1_gn, 4) # RF 3x3 --> weighted with attention map 1
- x2_gn = torch.unsqueeze(x2_gn, 4) # RF 5x5 --> weighted with attention map 2
- x2 = torch.cat((x2_gn, x1_gn), dim=4) # Concatenating along the 5th dimension
- x2_max, _ = torch.max(x2, 4)
+ x2_max = torch.maximum(x2_gn, x1_gn)
x2 = self.prelu(x2_max)

# Convolution block 3 (RF: 7x7)
x2 = self.conv2(x2)
x3_gn = self.gn3(x2)

# Second Maxout
- x3_gn = torch.unsqueeze(x3_gn, 4) # RF 7x7 --> weighted with attention map 3
- x2_max = torch.unsqueeze(x2_max, 4) # RF 3x3 and 5x5 from First Maxout (weighted with map 1 and 2)
- x3 = torch.cat((x3_gn, x2_max), dim=4) # Concatenating along the 5th dimension
- x3_max, _ = torch.max(x3, 4)
+
+ x3_max = torch.maximum(x3_gn, x2_max)
x3 = self.prelu(x3_max)

# Convolution block 4 (RF: 9x9)
@@ -182,32 +177,23 @@ def forward(self, x):
x1_bn = self.bn1(x0)

# First Maxout/Addition
- x0_bn = torch.unsqueeze(x, 4) # Original input --> weighted with attention map 1
- x1_bn = torch.unsqueeze(x1_bn, 4) # RF 3x3 --> weighted with attention map 2
- x1 = torch.cat((x1_bn, x0_bn), dim=4) # Concatenate along the 5th dimension NB x C x H x W x F
- x1_max, _ = torch.max(x1, 4)
+ x1_max = torch.maximum(x, x1_bn)
x1 = self.prelu(x1_max)

# Convolution block 2
x1 = self.conv1(x1)
x2_bn = self.bn2(x1)

# Second Maxout/Addition
- x2_bn = torch.unsqueeze(x2_bn, 4) # RF 5x5 --> weighted with attention map 3
- x1_max = torch.unsqueeze(x1_max, 4) # Original and 3x3 weighted with attention map 1 and 2
- x2 = torch.cat((x2_bn, x1_max), dim=4) # Concatenating along the 5th dimension
- x2_max, _ = torch.max(x2, 4)
+ x2_max = torch.maximum(x2_bn, x1_max)
x2 = self.prelu(x2_max)

# Convolution block 3
x2 = self.conv2(x2)
x3_bn = self.bn3(x2)

# Third Maxout/Addition
- x3_bn = torch.unsqueeze(x3_bn, 4) # RF 7x7 --> weighted with attention map 4
- x2_max = torch.unsqueeze(x2_max, 4) # orig, 3x3, 5x5 weighted with attention map 1-3
- x3 = torch.cat((x3_bn, x2_max), dim=4) # Concatenating along the 5th dimension
- x3_max, _ = torch.max(x3, 4)
+ x3_max = torch.maximum(x3_bn, x2_max)
x3 = self.prelu(x3_max)

# Convolution block 4 (end with batch-normed output to allow maxout across skip-connections)
@@ -301,21 +287,15 @@ def forward(self, x):
x2_bn = self.bn2(x1)

# First Maxout
- x1_bn = torch.unsqueeze(x1_bn, 4) # RF 3x3
- x2_bn = torch.unsqueeze(x2_bn, 4) # RF 5x5
- x2 = torch.cat((x2_bn, x1_bn), dim=4) # Concatenating along the 5th dimension
- x2_max, _ = torch.max(x2, 4)
+ x2_max = torch.maximum(x2_bn, x1_bn)
x2 = self.prelu(x2_max)

# Convolution block3 (RF: 7x7)
x2 = self.conv2(x2)
x3_bn = self.bn3(x2)

# Second Maxout
- x3_bn = torch.unsqueeze(x3_bn, 4) # RF 7x7
- x2_max = torch.unsqueeze(x2_max, 4) # RF 3x3 and 5x5 from First Maxout (weighted with map 1 and 2)
- x3 = torch.cat((x3_bn, x2_max), dim=4) # Concatenating along the 5th dimension
- x3_max, _ = torch.max(x3, 4)
+ x3_max = torch.maximum(x3_bn, x2_max)
x3 = self.prelu(x3_max)

# Convolution block 4 (RF: 9x9)
@@ -426,11 +406,7 @@ def forward(self, x, out_block, indices):
:return: processed feature maps
"""
unpool = self.unpool(x, indices)
- unpool = torch.unsqueeze(unpool, 4)
-
- out_block = torch.unsqueeze(out_block, 4)
- concat = torch.cat((unpool, out_block), dim=4) # Competitive Concatenation
- concat_max, _ = torch.max(concat, 4)
+ concat_max = torch.maximum(unpool, out_block)
out_block = super(CompetitiveDecoderBlock, self).forward(concat_max)

return out_block
@@ -511,21 +487,15 @@ def forward(self, x, out_block):
x2_gn = self.gn2(x1)

# First Maxout
- x1_gn = torch.unsqueeze(x1_gn, 4)
- x2_gn = torch.unsqueeze(x2_gn, 4) # Add Singleton Dimension along 5th
- x2 = torch.cat((x2_gn, x1_gn), dim=4) # Concatenating along the 5th dimension
- x2_max, _ = torch.max(x2, 4)
+ x2_max = torch.maximum(x1_gn, x2_gn)

x2 = self.prelu(x2_max)
# Convolution block3; 7x7
x2 = self.conv2(x2)
x3_gn = self.gn3(x2)

# Second Maxout
- x3_gn = torch.unsqueeze(x3_gn, 4)
- x2_max = torch.unsqueeze(x2_max, 4) # Add Singleton Dimension along 5th
- x3 = torch.cat((x3_gn, x2_max), dim=4) # Concatenating along the 5th dimension
- x3_max, _ = torch.max(x3, 4)
+ x3_max = torch.maximum(x3_gn, x2_max)

x3 = self.prelu(x3_max)
# Convolution block 4; 9x9
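Note on the change: torch.maximum(a, b) computes the element-wise maximum of two same-shaped tensors, so it produces exactly the same values as the removed unsqueeze/cat/max idiom while skipping the intermediate 5-D concatenated tensor. A minimal sketch verifying the equivalence (tensor names and shapes here are illustrative, not taken from the repository):

    import torch

    # Two same-shaped feature maps, e.g. N x C x H x W as in sub_module.py
    a = torch.randn(2, 4, 8, 8)
    b = torch.randn(2, 4, 8, 8)

    # Old idiom: add a singleton 5th dimension, concatenate, reduce with max
    stacked = torch.cat((torch.unsqueeze(a, 4), torch.unsqueeze(b, 4)), dim=4)
    old_max, _ = torch.max(stacked, 4)

    # New idiom: direct element-wise maximum, no intermediate tensor
    new_max = torch.maximum(a, b)

    assert torch.equal(old_max, new_max)  # forward results are identical

The same rewrite is applied uniformly to every Maxout/competitive-concatenation site in the diff above, which is why the change reduces 41 deleted lines to 11 added ones without altering the network's output.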
