Skip to content

Commit

Permalink
Merge pull request #4 from MBTMBTMBT/update-model
Browse files Browse the repository at this point in the history
Update model
  • Loading branch information
MBTMBTMBT authored Feb 23, 2024
2 parents 1cb2e3b + ae3975c commit 4f78c7e
Show file tree
Hide file tree
Showing 12 changed files with 595 additions and 451 deletions.
240 changes: 42 additions & 198 deletions common/helpers/torch_module/src/torch_module/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,134 +4,7 @@
import torchvision.models as models


class CombinedModelNoRegression(nn.Module):
def __init__(self, segment_model: nn.Module, predict_model: nn.Module, cat_layers: int = None):
super(CombinedModelNoRegression, self).__init__()
self.segment_model = segment_model
self.predict_model = predict_model
self.cat_layers = cat_layers

def forward(self, x: torch.Tensor):
seg_masks = self.segment_model(x)
if self.cat_layers:
seg_masks_ = seg_masks[:, 0:self.cat_layers]
x = torch.cat((x, seg_masks_), dim=1)
else:
x = torch.cat((x, seg_masks), dim=1)
logic_outputs = self.predict_model(x)
return seg_masks, logic_outputs


class ASPP(nn.Module):
def __init__(self, in_channels, out_channels):
super(ASPP, self).__init__()
self.atrous_block1 = nn.Conv2d(in_channels, out_channels, 1, 1)
self.atrous_block6 = nn.Conv2d(
in_channels, out_channels, 3, padding=6, dilation=6)
self.atrous_block12 = nn.Conv2d(
in_channels, out_channels, 3, padding=12, dilation=12)
self.atrous_block18 = nn.Conv2d(
in_channels, out_channels, 3, padding=18, dilation=18)
self.conv_out = nn.Conv2d(out_channels * 4, out_channels, 1, 1)

def forward(self, x):
x1 = self.atrous_block1(x)
x6 = self.atrous_block6(x)
x12 = self.atrous_block12(x)
x18 = self.atrous_block18(x)
x = torch.cat([x1, x6, x12, x18], dim=1)
return self.conv_out(x)


class DeepLabV3PlusMobileNetV3(nn.Module):
def __init__(self, num_classes, in_channels=3, sigmoid=True):
super(DeepLabV3PlusMobileNetV3, self).__init__()
self.sigmoid = sigmoid
mobilenet_v3 = models.mobilenet_v3_large(pretrained=True)

if in_channels != 3:
mobilenet_v3.features[0][0] = nn.Conv2d(
in_channels, 16, kernel_size=3, stride=2, padding=1, bias=False
)

self.encoder = mobilenet_v3.features

intermediate_channel = self.encoder[-1].out_channels
self.aspp = ASPP(intermediate_channel, 256)

self.decoder = nn.Sequential(
# Concatenated with original input
nn.Conv2d(256 + in_channels, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, kernel_size=1)
)

def forward(self, x):
original_input = x
x_encoded = self.encoder(x)
x_aspp = self.aspp(x_encoded)

x = F.interpolate(
x_aspp, size=original_input.shape[2:], mode='bilinear', align_corners=False)
# Concatenate with original input
x = torch.cat([x, original_input], dim=1)
x = self.decoder(x)

if self.sigmoid:
x = torch.sigmoid(x)

return x


class MultiLabelMobileNetV3Small(nn.Module):
def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=True):
super(MultiLabelMobileNetV3Small, self).__init__()
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=pretrained)
self.sigmoid = sigmoid

if input_channels != 3:
mobilenet_v3_small.features[0][0] = nn.Conv2d(
input_channels, 16, kernel_size=3, stride=2, padding=1, bias=False
)

self.model = mobilenet_v3_small

num_ftrs = self.model.classifier[3].in_features
self.model.classifier[3] = nn.Linear(num_ftrs, num_labels)

def forward(self, x):
x = self.model(x)
if self.sigmoid:
x = torch.sigmoid(x)
return x


class MultiLabelMobileNetV3Large(nn.Module):
def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=True):
super(MultiLabelMobileNetV3Large, self).__init__()
mobilenet_v3_small = models.mobilenet_v3_large(pretrained=pretrained)
self.sigmoid = sigmoid

if input_channels != 3:
mobilenet_v3_small.features[0][0] = nn.Conv2d(
input_channels, 16, kernel_size=3, stride=2, padding=1, bias=False
)

self.model = mobilenet_v3_small

num_ftrs = self.model.classifier[3].in_features
self.model.classifier[3] = nn.Linear(num_ftrs, num_labels)

def forward(self, x):
x = self.model(x)
if self.sigmoid:
x = torch.sigmoid(x)
return x


def x2conv(in_channels, out_channels, inner_channels=None):
def X2conv(in_channels, out_channels, inner_channels=None):
inner_channels = out_channels // 2 if inner_channels is None else inner_channels
down_conv = nn.Sequential(
nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False),
Expand All @@ -143,78 +16,39 @@ def x2conv(in_channels, out_channels, inner_channels=None):
return down_conv


class Encoder(nn.Module):
def __init__(self, in_channels, out_channels):
super(Encoder, self).__init__()
self.down_conv = x2conv(in_channels, out_channels)
self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

def forward(self, x):
x = self.down_conv(x)
x = self.pool(x)
return x


class Decoder(nn.Module):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, skip_channels, out_channels):
super(Decoder, self).__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.up_conv = x2conv(in_channels, out_channels)
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
self.up_conv = X2conv(out_channels + skip_channels, out_channels)

def forward(self, x_copy, x, interpolate=True):
def forward(self, x_copy, x):
x = self.up(x)

if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)):
if interpolate:
# Iterpolating instead of padding
x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)),
mode="bilinear", align_corners=True)
else:
# Padding in case the incomping volumes are of different sizes
diffY = x_copy.size()[2] - x.size()[2]
diffX = x_copy.size()[3] - x.size()[3]
x = F.pad(x, (diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2))

# Concatenate
x = torch.cat([x_copy, x], dim=1)
if x.size(2) != x_copy.size(2) or x.size(3) != x_copy.size(3):
x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)), mode='bilinear', align_corners=True)
x = torch.cat((x_copy, x), dim=1)
x = self.up_conv(x)
return x


class UNetWithResnet18Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, in_channels, skip_channels, out_channels):
super(UNetWithResnet18Encoder.Decoder, self).__init__()
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
self.up_conv = x2conv(out_channels + skip_channels, out_channels)

def forward(self, x_copy, x):
x = self.up(x)
if x.size(2) != x_copy.size(2) or x.size(3) != x_copy.size(3):
x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)), mode='bilinear', align_corners=True)
x = torch.cat((x_copy, x), dim=1)
x = self.up_conv(x)
return x

class UNetWithResnetEncoder(nn.Module):
def __init__(self, num_classes, in_channels=3, freeze_bn=False, sigmoid=True):
super(UNetWithResnet18Encoder, self).__init__()
super(UNetWithResnetEncoder, self).__init__()
self.sigmoid = sigmoid
resnet18 = models.resnet18(pretrained=False)

self.resnet = models.resnet34(pretrained=True) # Initialize with a ResNet model
if in_channels != 3:
resnet18.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

self.encoder1 = nn.Sequential(resnet18.conv1, resnet18.bn1, resnet18.relu)
self.encoder2 = resnet18.layer1
self.encoder3 = resnet18.layer2
self.encoder4 = resnet18.layer3
self.encoder5 = resnet18.layer4
self.encoder1 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu)
self.encoder2 = self.resnet.layer1
self.encoder3 = self.resnet.layer2
self.encoder4 = self.resnet.layer3
self.encoder5 = self.resnet.layer4

self.up1 = UNetWithResnet18Encoder.Decoder(512, 256, 256)
self.up2 = UNetWithResnet18Encoder.Decoder(256, 128, 128)
self.up3 = UNetWithResnet18Encoder.Decoder(128, 64, 64)
self.up4 = UNetWithResnet18Encoder.Decoder(64, 64, 64)
self.up1 = Decoder(512, 256, 256)
self.up2 = Decoder(256, 128, 128)
self.up3 = Decoder(128, 64, 64)
self.up4 = Decoder(64, 64, 64)

self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
self._initialize_weights()
Expand All @@ -233,7 +67,7 @@ def _initialize_weights(self):
module.bias.data.zero_()

def forward(self, x):
x1 = self.encoder1(x)
x1 = self.encoder1(x)
x2 = self.encoder2(x1)
x3 = self.encoder3(x2)
x4 = self.encoder4(x3)
Expand All @@ -243,10 +77,10 @@ def forward(self, x):
x = self.up2(x3, x)
x = self.up3(x2, x)
x = self.up4(x1, x)
x = F.interpolate(x, size=(x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
x = F.interpolate(x, size=(x.size(2) * 2, x.size(3) * 2), mode='bilinear', align_corners=True)

x = self.final_conv(x)

if self.sigmoid:
x = torch.sigmoid(x)
return x
Expand All @@ -256,18 +90,23 @@ def freeze_bn(self):
if isinstance(module, nn.BatchNorm2d):
module.eval()

def unfreeze_bn(self):
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.train()


class MultiLabelResNet(nn.Module):
def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=False,):
super(MultiLabelResNet, self).__init__()
self.model = models.resnet18(pretrained=pretrained)
self.model = models.resnet34(pretrained=pretrained)
self.sigmoid = sigmoid

if input_channels != 3:
self.model.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

num_ftrs = self.model.fc.in_features

self.model.fc = nn.Linear(num_ftrs, num_labels)

def forward(self, x):
Expand All @@ -278,21 +117,26 @@ def forward(self, x):


class CombinedModel(nn.Module):
def __init__(self, segment_model: nn.Module, predict_model: nn.Module, cat_layers:int=None):
def __init__(self, segment_model: nn.Module, predict_model: nn.Module, cat_layers: int=None):
super(CombinedModel, self).__init__()
self.segment_model = segment_model
self.predict_model = predict_model
self.cat_layers = cat_layers
self.freeze_seg = False

def forward(self, x: torch.Tensor):
seg_masks = self.segment_model(x)

seg_masks_ = seg_masks.detach()
if self.cat_layers:
seg_masks_ = seg_masks[:, 0:self.cat_layers]
seg_masks_ = seg_masks_[:, 0:self.cat_layers]
x = torch.cat((x, seg_masks_), dim=1)
else:
x = torch.cat((x, seg_masks), dim=1)

x = torch.cat((x, seg_masks_), dim=1)
logic_outputs = self.predict_model(x)
return seg_masks, logic_outputs

def freeze_segment_model(self):
self.segment_model.eval()

def unfreeze_segment_model(self):
self.segment_model.train()
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ def load_model_cached(dataset: str) -> None:
model = loaded_models[dataset]
else:
if dataset == 'resnet50':
# name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16)
# rospy.logwarn(name)
# model = load_model(name)
model = load_model('/home/rexy/.keras/tf-bodypix/3fe1b130a0f20e98340612c099b50c18--tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16')
# model = load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16))
elif dataset == 'mobilenet50':
model = load_model(download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16))
name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16)
rospy.logwarn(name)
model = load_model(name)
else:
model = load_model(dataset)

Expand Down
4 changes: 3 additions & 1 deletion common/vision/lasr_vision_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ add_message_files(
BodyPixMaskRequest.msg
ColourPrediction.msg
FeatureWithColour.msg
# Description.msg
)

## Generate services in the 'srv' folder
Expand All @@ -59,7 +60,8 @@ add_service_files(
YoloDetection.srv
BodyPixDetection.srv
TorchFaceFeatureDetection.srv
Recognise.srv
TorchFaceFeatureDetectionDescription.srv
# Recognise.srv
)

## Generate actions in the 'action' folder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Image to run inference on
sensor_msgs/Image image_raw

uint8[] head_mask_data # For serialized array data
uint32[] head_mask_shape # To store the shape of the array
string head_mask_dtype # Data type of the array elements

uint8[] torso_mask_data
uint32[] torso_mask_shape
string torso_mask_dtype
---

# Detection result
string description
Loading

0 comments on commit 4f78c7e

Please sign in to comment.