Semantic Segmentation: Multiclass Fine-Tuning of DeepLabV3 with PyTorch
The code in this repository fine-tunes DeepLabV3 with PyTorch for multiclass semantic segmentation.
Example result on a test image (not in the dataset)
Basic dependencies are PyTorch 1.4.0 and torchvision 0.5.0.
I used a conda virtual env, where I installed the following packages:
conda install -c conda-forge -c pytorch python=3.7 pytorch torchvision cudatoolkit=10.1 opencv numpy pillow
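A quick way to verify the environment afterwards (my own sanity check, not part of the repo):

python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__, torch.cuda.is_available())"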
I created a dataset from my own personal skydiving pictures.
Around 500 images were gathered and annotated using the excellent tool CVAT: https://github.com/opencv/cvat
/!\ In this repo, I only uploaded a few images, so as to give an idea of the format I used.
I wrote a script to easily convert one of CVAT's XML export types (LabelMe ZIP 3.0 for images) into label images.
There are 5 classes in my example:
- No-label: 0
- Person: 1
- Airplane: 2
- Ground: 3
- Sky: 4
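In a label image, each pixel directly stores the class index it belongs to. As an illustration only (the actual conversion script is in the repo; the XML field names below assume CVAT's LabelMe 3.0 export format, and the lowercase class names are a hypothetical mapping), polygon annotations can be rasterized into such a label image with PIL:

import xml.etree.ElementTree as ET
from PIL import Image, ImageDraw

# Hypothetical name-to-index mapping, matching the class list above
CLASS_TO_INDEX = {"person": 1, "airplane": 2, "ground": 3, "sky": 4}

def labelme_xml_to_label_image(xml_path, width, height):
    """Rasterize LabelMe polygons into a single-channel label image."""
    label = Image.new("L", (width, height), 0)  # 0 = no-label background
    draw = ImageDraw.Draw(label)
    root = ET.parse(xml_path).getroot()
    for obj in root.findall("object"):
        name = obj.find("name").text.lower()
        points = [(float(pt.find("x").text), float(pt.find("y").text))
                  for pt in obj.find("polygon").findall("pt")]
        # Fill the polygon with the class index (later objects overwrite earlier ones)
        draw.polygon(points, fill=CLASS_TO_INDEX.get(name, 0))
    return label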
Once you have replaced sample_dataset with your own dataset, launch the training:
python sources/main_training.py ./sample_dataset ./training_output --num_classes 5 --epochs 100 --batch_size 16 --keep_feature_extract
The best values I obtained were Loss: 0.2066 and Accuracy: 0.8099 after 100 epochs. The accuracy is computed as the mean IoU (Intersection-over-Union) over all classes.
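For reference, a minimal sketch of this metric (the repo's exact implementation may differ); pred and target are integer class maps of the same shape:

import torch

def mean_iou(pred, target, num_classes):
    """Mean Intersection-over-Union over all classes present in the data."""
    ious = []
    for cls in range(num_classes):
        pred_mask = (pred == cls)
        target_mask = (target == cls)
        union = (pred_mask | target_mask).sum().item()
        if union == 0:
            continue  # class absent from both prediction and target: skip it
        intersection = (pred_mask & target_mask).sum().item()
        ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0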
The first step is to fetch a pretrained DeepLabV3 model.
It is pretrained on a subset of COCO train2017, on the 20 categories that are present in the Pascal VOC dataset.
model_deeplabv3 = models.segmentation.deeplabv3_resnet101(pretrained=use_pretrained, progress=True)
The auxiliary classifier is removed, and the pretrained weights are frozen.
model_deeplabv3.aux_classifier = None
for param in model_deeplabv3.parameters():
    param.requires_grad = False
The pretrained classifier is replaced by a new one with a custom number of classes. Since this happens after the freeze, the new classifier's weights are not frozen; they are the ones that we will fine-tune.
model_deeplabv3.classifier = torchvision.models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)
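For context, torchvision's DeepLabHead(2048, num_classes) stacks an ASPP module and a final 1x1 convolution on top of the 2048-channel feature map produced by the ResNet-101 backbone; only these new layers will be trained.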
The following data augmentations are applied to the training set:
self.transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
    # The 4th channel is the label: mean 0 and std 1 leave its values untouched
    transforms.Normalize([0.485, 0.456, 0.406, 0], [0.229, 0.224, 0.225, 1])
])
For the validation set, only a center crop and normalization are applied:
self.transforms = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406, 0], [0.229, 0.224, 0.225, 1])
])
To ensure that the same transformation is applied to the input image and the expected output label image, both are merged into a single 4-channel image before the transformation. They are then split back into two separate tensors for training.
import numpy as np
from PIL import Image

image = Image.open(img_path)
label = Image.open(label_path)

# Concatenate image and label, to apply the same transformation on both
image_np = np.asarray(image)
label_np = np.asarray(label)
new_shape = (image_np.shape[0], image_np.shape[1], image_np.shape[2] + 1)
image_and_label_np = np.zeros(new_shape, image_np.dtype)
image_and_label_np[:, :, 0:3] = image_np
image_and_label_np[:, :, 3] = label_np

# Convert back to a (4-channel) PIL image
image_and_label = Image.fromarray(image_and_label_np)

# Apply the transforms
image_and_label = self.transforms(image_and_label)

# Split back: the first 3 channels are the image, the 4th is the label
image = image_and_label[0:3, :, :]
label = image_and_label[3, :, :].unsqueeze(0)
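One subtlety worth noting (assuming 8-bit label images, as produced above): ToTensor rescales uint8 values into [0, 1], so after the transform the label channel holds class_index / 255. Before the label can be fed to the loss, the integer class indices have to be recovered, e.g.:

label = (label * 255).round().long()  # undo ToTensor's division by 255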
The chosen training loss is Cross Entropy (https://pytorch.org/docs/stable/nn.html#crossentropyloss) since it is well suited for multiclass classification problems.
# Setup the loss function
criterion = nn.CrossEntropyLoss(weight=(torch.FloatTensor(weight).to(device) if weight else None))
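The optional weight argument can counterbalance class imbalance, for instance when sky and ground pixels vastly outnumber person pixels; passing None leaves all classes equally weighted.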
The optimizer is SGD with a learning rate of 0.001 and a momentum of 0.9.
Only the classifier parameters are optimized.
params_to_update = []
for name, param in model_deeplabv3.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
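Putting the pieces together, one training iteration looks roughly like this (a sketch, assuming a DataLoader named dataloader that yields image batches and integer label maps; torchvision's segmentation models return a dict whose "out" entry holds the logits):

model_deeplabv3 = model_deeplabv3.to(device)
model_deeplabv3.train()
for inputs, labels in dataloader:
    inputs = inputs.to(device)
    labels = labels.to(device).squeeze(1).long()  # (N, H, W) class indices
    optimizer_ft.zero_grad()
    outputs = model_deeplabv3(inputs)["out"]      # (N, num_classes, H, W) logits
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer_ft.step()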