Skip to content

Commit

Permalink
Added new model structure.
Browse files Browse the repository at this point in the history
  • Loading branch information
MBTMBTMBT committed May 13, 2024
1 parent 9b07953 commit aea2739
Showing 1 changed file with 132 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,138 @@ def unfreeze_segment_model(self):
self.segment_model.train()


class SegmentPredictor(nn.Module):
    """Joint segmentation / multi-label prediction network on a ResNet-18 encoder.

    The encoder stages are taken from a pretrained torchvision ResNet-18.
    A chain of ``Decoder`` blocks fuses the encoder features back together
    and feeds a 1x1 convolution that emits ``num_masks`` mask channels.  A
    second head runs on the deepest encoder feature map only and emits
    ``num_labels`` scores per sample.

    Args:
        num_masks: Number of output channels of the segmentation head.
        num_labels: Number of outputs of the classification head.
        in_channels: Number of input image channels.  When this is not 3 the
            first ResNet convolution is replaced by a freshly initialised one,
            so that layer does not carry pretrained weights.
        sigmoid: If true, ``forward`` applies a sigmoid to both outputs.
    """

    def __init__(self, num_masks, num_labels, in_channels=3, sigmoid=True):
        super().__init__()
        self.sigmoid = sigmoid

        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=`` -- confirm against the pinned version.
        self.resnet = models.resnet18(pretrained=True)

        # The stock stem expects RGB; swap it out for any other channel count.
        if in_channels != 3:
            self.resnet.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # --- Encoder -------------------------------------------------------
        # Only conv1/bn1/relu form the stem; the ResNet max-pool is left out.
        backbone = self.resnet
        self.encoder1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.encoder2 = backbone.layer1
        self.encoder3 = backbone.layer2
        self.encoder4 = backbone.layer3
        self.encoder5 = backbone.layer4

        # --- Decoder -------------------------------------------------------
        # Channel widths below match resnet18/34.  For resnet50/101/152 the
        # original author's alternative was:
        #   up1: (2048, 1024, 1024), up2: (1024, 512, 512),
        #   up3: (512, 256, 256),    up4: (256, 64, 64)
        # NOTE(review): ``Decoder`` is defined elsewhere; its arguments are
        # presumably (in, skip/mid, out) channel counts -- confirm.
        self.up1 = Decoder(512, 256, 256)
        self.up2 = Decoder(256, 128, 128)
        self.up3 = Decoder(128, 64, 64)
        self.up4 = Decoder(64, 64, 64)

        # --- Segmentation head --------------------------------------------
        self.final_conv = nn.Conv2d(64, num_masks, kernel_size=1)

        # --- Classification head ------------------------------------------
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Two 3x3 convolutions widen the 512-channel resnet18/34 feature map
        # to 2048 channels before pooling (a deeper ResNet would start this
        # stack from 2048 input channels instead).
        self.predictor_cnn_extension = nn.Sequential(
            nn.Conv2d(512, 2048, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
        )
        # The 2048 inputs here are the pooled output of the stack above.
        self.classifier = nn.Sequential(
            nn.Linear(2048, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Dropout(p=0.5),
            nn.Linear(256, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_labels),
        )

    def forward(self, x):
        """Run the network and return ``(mask, labels)``.

        ``mask`` comes from the decoder path and ``labels`` from the deepest
        encoder features.  Both are passed through a sigmoid when
        ``self.sigmoid`` is true, otherwise they are returned as raw outputs.
        """
        # Encoder: keep every stage, the decoder needs them as skip inputs.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)

        # Decoder: each block receives a shallower encoder feature together
        # with the running decoder state.
        decoded = self.up1(enc4, enc5)
        decoded = self.up2(enc3, decoded)
        decoded = self.up3(enc2, decoded)
        decoded = self.up4(enc1, decoded)

        # Double the spatial size of the decoder output before the 1x1 head.
        height, width = decoded.size(2), decoded.size(3)
        decoded = F.interpolate(
            decoded,
            size=(height * 2, width * 2),
            mode='bilinear',
            align_corners=True,
        )
        mask = self.final_conv(decoded)

        # Classification branch: widen, pool to 1x1, flatten, classify.
        pooled = self.global_pool(self.predictor_cnn_extension(enc5))
        labels = self.classifier(pooled.view(pooled.size(0), -1))

        if not self.sigmoid:
            return mask, labels
        return torch.sigmoid(mask), torch.sigmoid(labels)


class SegmentPredictorBbox(SegmentPredictor):
    """``SegmentPredictor`` extended with a bounding-box regression head.

    In addition to the mask and label outputs of the parent class, this
    model produces four values per bounding-box class from the deepest
    encoder feature map.

    Args:
        num_masks: Forwarded to ``SegmentPredictor``.
        num_labels: Forwarded to ``SegmentPredictor``.
        num_bbox_classes: Number of boxes predicted per sample; the head
            outputs ``num_bbox_classes * 4`` values.
        in_channels: Forwarded to ``SegmentPredictor``.
        sigmoid: Forwarded to ``SegmentPredictor``; never applied to boxes.
    """

    def __init__(self, num_masks, num_labels, num_bbox_classes, in_channels=3, sigmoid=True):
        super().__init__(num_masks, num_labels, in_channels, sigmoid)
        self.num_bbox_classes = num_bbox_classes

        # Same widening stack as the classification branch, but with its own
        # weights: 512 input channels (resnet18/34) up to 2048.
        self.bbox_cnn_extension = nn.Sequential(
            nn.Conv2d(512, 2048, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
        )
        # MLP mapping the pooled 2048-d vector to 4 numbers per box class.
        # Unlike the label classifier, this head has no dropout layers.
        self.bbox_generator = nn.Sequential(
            nn.Linear(2048, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, num_bbox_classes * 4),
        )

    def forward(self, x):
        """Run the network and return ``(mask, labels, bboxes)``.

        ``bboxes`` is reshaped to ``(-1, num_bbox_classes, 4)`` and is always
        returned without a sigmoid; ``mask`` and ``labels`` get a sigmoid
        only when ``self.sigmoid`` is true.
        """
        # Encoder: keep every stage, the decoder needs them as skip inputs.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)

        # Decoder: each block receives a shallower encoder feature together
        # with the running decoder state.
        decoded = self.up1(enc4, enc5)
        decoded = self.up2(enc3, decoded)
        decoded = self.up3(enc2, decoded)
        decoded = self.up4(enc1, decoded)

        # Double the spatial size of the decoder output before the 1x1 head.
        height, width = decoded.size(2), decoded.size(3)
        decoded = F.interpolate(
            decoded,
            size=(height * 2, width * 2),
            mode='bilinear',
            align_corners=True,
        )
        mask = self.final_conv(decoded)

        # Label branch: widen, pool to 1x1, flatten, classify.
        cls_feat = self.global_pool(self.predictor_cnn_extension(enc5))
        labels = self.classifier(cls_feat.view(cls_feat.size(0), -1))

        # Box branch: same recipe on the same encoder output, separate weights.
        box_feat = self.global_pool(self.bbox_cnn_extension(enc5))
        box_feat = box_feat.view(box_feat.size(0), -1)
        bboxes = self.bbox_generator(box_feat).view(-1, self.num_bbox_classes, 4)

        # Boxes are deliberately left unsquashed.
        if not self.sigmoid:
            return mask, labels, bboxes
        return torch.sigmoid(mask), torch.sigmoid(labels), bboxes


class Predictor:
def __init__(
self,
Expand Down

0 comments on commit aea2739

Please sign in to comment.