
MetaDamageNet

Using Deep Learning To Identify And Classify Building Damage

This project is my bachelor's thesis at Amirkabir University of Technology. Some ideas for this project are borrowed from the xview2 first place solution 1 repository. I used that repository as a baseline and refactored its code; thus, this project covers the models and experiments of that repository and contributes further to the same problem of building damage assessment.

Environment Setup

git clone https://github.com/nimaafshar/metadamagenet.git
cd metadamagenet/
pip install -r requirements.txt

Examples

Table Of Contents

Dataset

We use the xview2 2 challenge dataset, namely xBD 3, for our project. This dataset contains pairs of pre- and post-disaster images from 19 natural disasters worldwide, including fires, hurricanes, floods, and earthquakes. Each sample in the dataset consists of a pre-disaster image with its building annotations and a post-disaster image with the same building annotations. In the post-disaster building annotations, however, each building has one of the following damage levels: undamaged, minor damage, major damage, destroyed, and unclassified.

The dataset consists of the train, tier3, test, and hold subsets. Each subset has an images folder containing pre- and post-disaster images stored as 1024×1024 PNGs, and a labels folder containing building annotations and damage labels in JSON format. Some post-disaster images are slightly shifted from their corresponding pre-disaster images, and the imagery comes with different ground sample distances. We used the train and tier3 subsets for training, the test subset for validation, and the hold subset for testing.

The dataset is highly unbalanced in multiple aspects: buildings with the undamaged label far outnumber buildings with other damage types, and both the number of images and the number of building annotations vary a lot between different disasters.

Folder Structure
dataset
├── test
│   └── ... (similar to train)
├── hold
│   └── ... (similar to train)
├── tier3
│   └── ... (similar to train)
└── train
    ├── images
    │   ├── ...
    │   ├── {disaster}_{id}_post_disaster.png
    │   └── {disaster}_{id}_pre_disaster.png
    ├── labels
    │   ├── ...
    │   ├── {disaster}_{id}_post_disaster.json
    │   └── {disaster}_{id}_pre_disaster.json
    └── targets
        ├── ...
        ├── {disaster}_{id}_post_disaster_target.png
        └── {disaster}_{id}_pre_disaster_target.png
Example Usage
from pathlib import Path
from metadamagenet.dataset import LocalizationDataset, ClassificationDataset

dataset = LocalizationDataset(Path('/path/to/dataset/train'))
dataset = ClassificationDataset([Path('/path/to/dataset/train'), Path('/path/to/dataset/tier3')])

an example of data

Problem Definition

We can convert these building annotations (polygons) to a binary mask. We can also convert the damage levels to the values 1-4 and use them as the value of all the pixels inside the corresponding building, forming a semantic segmentation mask. Thus, we define the building localization task as predicting whether each pixel's value is zero or non-zero, and the damage classification task as predicting the exact value of the pixels within each building. We consider the label of an unclassified building to be undamaged, as it is by far the most common label in the dataset.
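As a rough illustration of this conversion (not the repository's actual preprocessing code; the helper and the label-name mapping are assumptions), building polygons could be rasterized into a damage mask like this:

import numpy as np
from PIL import Image, ImageDraw

# assumed mapping from xBD damage labels to mask values; 0 is kept for the background,
# and unclassified buildings are treated as undamaged
DAMAGE_VALUE = {"no-damage": 1, "minor-damage": 2, "major-damage": 3, "destroyed": 4, "un-classified": 1}

def polygons_to_mask(polygons, size=(1024, 1024)) -> np.ndarray:
    """polygons: iterable of (vertices, label), where vertices is a list of (x, y) tuples."""
    mask = Image.new("L", size, 0)  # single-channel image, background = 0
    draw = ImageDraw.Draw(mask)
    for vertices, label in polygons:
        draw.polygon(vertices, fill=DAMAGE_VALUE[label])  # fill the building with its damage value
    return np.asarray(mask)

# a localization target is then simply (mask > 0), a binary building mask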

Data Augmentations

Example Usage
import torch
from metadamagenet.augment import Random, VFlip, Rotate90, Shift, RotateAndScale, BestCrop, OneOf, RGBShift, HSVShift,\
    Clahe, GaussianNoise, Blur, Saturation, Brightness, Contrast, ElasticTransform

transform = torch.nn.Sequential(
    Random(VFlip(), p=0.5),
    Random(Rotate90(), p=0.95),
    Random(Shift(y=(.2, .8), x=(.2, .8)), p=.1),
    Random(RotateAndScale(center_y=(0.3, 0.7), center_x=(0.3, 0.7), angle=(-10., 10.), scale=(.9, 1.1)), p=0.1),
    BestCrop(samples=5, dsize=(512, 512), size_range=(0.45, 0.55)),
    Random(RGBShift().only_on('img'), p=0.01),
    Random(HSVShift().only_on('img'), p=0.01),
    OneOf(
        (OneOf(
            (Clahe().only_on('img'), 0.01),
            (GaussianNoise().only_on('img'), 0.01),
            (Blur().only_on('img'), 0.01)), 0.01),
        (OneOf(
            (Saturation().only_on('img'), 0.01),
            (Brightness().only_on('img'), 0.01),
            (Contrast().only_on('img'), 0.01)), 0.01)
    ),
    Random(ElasticTransform(), p=0.001)
)

inputs = {
    'img': torch.rand(3, 3, 100, 100),
    'msk': torch.randint(low=0, high=2, size=(3, 100, 100))
}
outputs = transform(inputs)

Data augmentation techniques help generate new valid samples from the dataset. Hence, they provide us with more data, help the model train faster, and prevent overfitting. Data augmentation is used extensively in training computer vision models, from image classification to instance segmentation. In most cases, augmentation is applied randomly: it is applied to only a subset of the original samples, and its parameters are drawn at random. Most libraries used for augmentation, like OpenCV 4, do not support image-batch transforms and only perform transforms on the CPU. Kornia 5 6 is an open-source differentiable computer vision library for PyTorch 7; it supports image-batch transforms and performs them on the GPU. We used Kornia and extended parts of it to suit our project's requirements.

We created our own version of each image transformation so that it meets our needs. Its input is multiple named batches; for example, an input contains a batch of images and a batch of corresponding segmentation masks. In some transformations, like resize, the same parameters (in this case, the scale) should be used for both the images and the segmentation masks. In others, like channel shift, the transformation should not be applied to the segmentation masks at all. Another requirement is that the transformation parameters can differ for each image and its corresponding mask in the batch, so a random augmentation should generate different parameters for each image in the batch. Moreover, the transformation may be skipped entirely for some images in the batch. Our version of each transformation meets these requirements; a minimal sketch of such a transform follows.
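The sketch below is a simplified stand-in, not the metadamagenet implementation: a random vertical flip that draws an independent decision for every sample in the batch and applies the same decision to every named batch.

import torch
from torch import nn

class RandomVFlip(nn.Module):
    """Flip each sample vertically with probability p; the same per-sample decision
    is applied to every named batch (e.g. 'img' and its corresponding 'msk')."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, inputs: dict) -> dict:
        batch_size = next(iter(inputs.values())).shape[0]
        flip = torch.rand(batch_size) < self.p   # an independent random decision per sample
        outputs = {}
        for name, batch in inputs.items():
            out = batch.clone()
            out[flip] = torch.flip(batch[flip], dims=[-2])   # flip along the height axis
            outputs[name] = out
        return outputs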

Methodology

Example Usage
from metadamagenet.models import Localizer
from metadamagenet.models.unet import EfficientUnetB0


# define localizer of unet
class EfficientUnetB0Localizer(Localizer[EfficientUnetB0]): pass


# load pretrained model
pretrained_model = EfficientUnetB0Localizer.from_pretrained(version='00', seed=0)

# load an empty model
empty_model = EfficientUnetB0Localizer()

# load a model from pretrained unet
unet: EfficientUnetB0  # some pretrained unet
model_with_pretrained_unet = EfficientUnetB0Localizer(unet)

# load an empty unet
empty_unet = EfficientUnetB0()

# load a unet with pretrained backbone
unet_with_pretrained_backbone = EfficientUnetB0(pretrained_backbone=True)

General Architecture

As shown in the figure below, building-localization models consist of a feature extractor (a U-net 8 or a SegFormer 9) and a classifier module 1. The feature extractor extracts helpful features from the input image; then, the classifier module predicts a value of 0 or 1 for each pixel, indicating whether that pixel belongs to a building. In the classification models, the feature extractor extracts the same features from the pre-disaster and post-disaster images. In these models, the classifier module predicts a class between 0 and 4 for each pixel: the value 0 indicates that the pixel belongs to no building, while values 1-4 mean that the pixel belongs to a building and indicate the damage level at that pixel. The classifier module effectively learns a distance function between the pre-disaster and post-disaster images, because the damage level of each building can be determined by comparing it across the two images.

In many samples, the post-disaster image has a minor shift compared to the pre-disaster image, yet the segmentation masks are created from the buildings' locations in the pre-disaster image. This shift is an issue the model has to overcome. In our models, the feature-extracting weights are shared between the two images, which helps the model detect the shift or nadir difference. For models that share a common feature extractor (like SegFormerB0 Classifier and SegFormerB0 Localizer), we can initialize the classification model's feature extractor with the localization model's feature extractor. Since we do not use the localization model directly for damage assessment, training the localization model can be seen as a pre-training stage for the classification model.

General Architecture
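The following sketch illustrates this shared-weights design for the classification case. It is an illustrative stand-in that assumes a feature extractor returning a single feature map; the repository's actual classes differ.

import torch
from torch import nn

class DamageClassifier(nn.Module):
    """Illustrative sketch of the general architecture: a shared feature extractor applied
    to both images and a classifier head that compares the two feature maps per pixel."""

    def __init__(self, feature_extractor: nn.Module, feature_channels: int, n_classes: int = 5):
        super().__init__()
        self.feature_extractor = feature_extractor               # shared weights for pre/post images
        self.classifier = nn.Conv2d(2 * feature_channels, n_classes, kernel_size=1)

    def forward(self, pre_img: torch.Tensor, post_img: torch.Tensor) -> torch.Tensor:
        pre_features = self.feature_extractor(pre_img)           # same module, same weights
        post_features = self.feature_extractor(post_img)
        combined = torch.cat([pre_features, post_features], dim=1)
        return self.classifier(combined)                         # per-pixel logits for classes 0-4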

U-Models

Some models in this project use a U-net 8 module as the feature extractor and a simple 2D convolutional layer as the classifier. We call them U-models. Their feature extractor is a U-net 8 with five encoder and five decoder modules. The encoder modules are usually taken from a general feature extractor like ResNet-34 10. In the forward pass through each encoder module, the number of channels may or may not change, but the height and width of the feature map are halved. Usually, the five encoder modules together include all layers of a general feature extractor model (like ResNet-34 10) except for the classification layer. Each decoder module combines the output of the previous decoder module and the respective encoder module; for example, decoder module 2 combines the output of decoder module 1 and encoder module 3. Together they form a U-like structure, as shown in the figure below.

Unet

Decoder Modules

There are two variants of decoder modules: the standard decoder module and the SCSE 11 decoder module. The standard decoder module applies a 2D convolution and a ReLU activation to the input from the previous decoder, concatenates the result with the input from the respective encoder module, and applies another 2D convolution and ReLU activation. The SCSE decoder module works the same way, but in the last step it applies a "Concurrent Spatial and Channel Squeeze & Excitation" 11 module to the result. This SCSE module is supposed to help the model focus on the more important regions and channels of the image. The decoder modules in the xview2 first place solution 1 do not use batch normalization between the convolution and the activation; we added this layer to the decoder modules to prevent exploding gradients and make these modules more stable. A rough sketch of such a decoder block is shown after the figure below.

Decoder Modules
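Here is a rough sketch of a standard decoder block with the added batch normalization. Layer names, the upsampling step, and the exact ordering are assumptions, not the repository's code.

import torch
import torch.nn.functional as F
from torch import nn

class StandardDecoderModule(nn.Module):
    """Sketch of a standard U-net decoder block with batch normalization added
    between each convolution and its activation."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),  # added between convolution and activation
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # upsample to the skip's resolution
        x = self.conv1(x)
        x = torch.cat([x, skip], dim=1)  # concatenate with the respective encoder output
        return self.conv2(x)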

Backbone

We pick the encoder modules of our U-net modules from a general feature extractor model called the backbone network. The choice of the backbone network is the most crucial factor in the performance of a U-net model. Moreover, most of a U-net model's parameters belong to its backbone network, so this choice also significantly affects the model's size. The xview2 first place solution 1 used ResNet-34 10, Dual Path Network 92 12, SeResnext50 (32x4d) 13, and SeNet154 14 as backbone networks. We used EfficientNet-B0 and EfficientNet-B4 15 (both the standard and Wide-SE versions) as backbone networks, creating new U-models called Efficient-Unets. EfficientNets 15 have shown excellent results on the ImageNet 16 dataset, so they are good feature extractors, and they are relatively small; these two properties make them good choices for a backbone network.
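As an illustration of how multi-scale encoder features can be taken from such a backbone (using timm's features_only mode here; the repository may build its encoders differently):

import timm
import torch

# EfficientNet-B0 as a feature extractor that returns one feature map per resolution stage
backbone = timm.create_model("efficientnet_b0", pretrained=True, features_only=True)
features = backbone(torch.rand(1, 3, 512, 512))   # list of feature maps at strides 2, 4, 8, 16, 32
print(backbone.feature_info.channels())           # channels per stage, used to size the decoder modules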

We listed all the used models and their attributes in the table below.

| Model | Backbone | #Params | Batch Normalization | Decoder Type |
|---|---|---|---|---|
| Resnet34Unet | resnet_34 | 25,728,112 | No | Standard |
| SeResnext50Unet | se_resnext50_32x4d | 34,559,728 | No | Standard |
| Dpn92Unet | dpn_92 | 47,408,735 | No | SCSE - concat |
| SeNet154Unet | senet_154 | 124,874,656 | No | Standard |
| EfficientUnetB0 | efficientnet_b0 | 6,884,876 | Yes | Standard |
| EfficientUnetB0SCSE | | 6,903,860 | Yes | SCSE - no concat |
| EfficientUnetWideSEB0 | efficientnet_widese_b0 | 10,020,176 | Yes | Standard |
| EfficientUnetB4 | efficientnet_b0 | 20,573,144 | Yes | Standard |
| EfficientUnetB4SCSE | | 20,592,128 | Yes | SCSE - no concat |
| SegFormer | segformer_512*512_ade | 3,714,401 | | |

Meta-Learning

In meta-learning, a general problem, such as classifying different images (in the ImageNet dataset) or different letters (in the Omniglot 17 dataset), is seen as a distribution of tasks. In this approach, the tasks are essentially the same problem (like classifying letters) but vary in some parameters (like the script the letters belong to). We can take a similar approach to our problem: we view building detection and damage level classification as the general problem and take the disaster type (like a flood, hurricane, or wildfire) and the environment of the disaster (like a desert, forest, or urban area) as the varying factors. In distance-learning methods, a distance function returns the distance between the query sample and a sample of each class, and the query sample is classified into the class with the minimum distance. These methods are helpful when there is a large number of classes; in our case, however, the number of classes is fixed, so we used a model-agnostic approach. Model-agnostic meta-learning 18 algorithms find a set of parameters for the model that can be adapted to a new task by training on very few samples. We used the MAML 18 algorithm and considered every disaster a separate task. Since the MAML algorithm consumes a lot of memory, and the consumed memory is proportional to the model size, we used a model based on EfficientUnetB0 and only trained it for the building localization task.

Since the MAML algorithm trains the model much more slowly than regular training, and we had limited time to train our models, the results were not satisfactory. We trained EfficientUnetB0-Localizer with MAML with support shots equal to one or five and query shots equal to two or ten. Other training hyperparameters and evaluation results are available in the results section. We utilized the Higher 19 library to implement the MAML algorithm; a minimal sketch of the inner and outer loops is shown below.
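The sketch below shows a single MAML meta-update with the higher library. The task and batch structure, the loss function, and the optimizers are placeholders, not MetaTrainer's internals.

import higher
import torch

def maml_step(model, tasks, loss_fn, meta_opt, inner_opt, n_inner_iter: int = 5):
    """One meta-update: adapt on each task's support shots, then backpropagate the query loss."""
    meta_opt.zero_grad()
    for support_batch, query_batch in tasks:  # one (support, query) split per disaster
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for _ in range(n_inner_iter):  # inner loop: adapt on the support shots
                support_loss = loss_fn(fmodel(support_batch["img"]), support_batch["msk"])
                diffopt.step(support_loss)
            # outer loop: evaluate the adapted model on the query shots;
            # gradients flow back to the original (meta) parameters
            query_loss = loss_fn(fmodel(query_batch["img"]), query_batch["msk"])
            query_loss.backward()
    meta_opt.step()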

Example Usage
from pathlib import Path

from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from metadamagenet.dataset import discover_directory, group_by_disasters, MetaDataLoader, LocalizationDataset, ImageData
from metadamagenet.metrics import xview2
from metadamagenet.losses import BinaryDiceLoss
from metadamagenet.runner import MetaTrainer, MetaValidationInTrainingParams
from metadamagenet.models import BaseModel

dataset: list[ImageData] = discover_directory(Path('/path/to/dataset'))  # path is illustrative
tasks: list[tuple[str, list[ImageData]]] = group_by_disasters(dataset)

train = MetaDataLoader(LocalizationDataset, tasks[:-2], task_set_size=17, support_shots=4, query_shots=8, batch_size=1)
test = MetaDataLoader(LocalizationDataset, tasks[-2:], task_set_size=2, support_shots=4, query_shots=8, batch_size=1)

model: BaseModel
version: str
seed: int
meta_optimizer: Optimizer
inner_optimizer: Optimizer
lr_scheduler: MultiStepLR
MetaTrainer(
    model,
    version,
    seed,
    train,
    nn.Identity(),
    meta_optimizer,
    inner_optimizer,
    lr_scheduler,
    BinaryDiceLoss(),
    epochs=50,
    n_inner_iter=5,
    score=xview2.localization_score,
    validation_params=MetaValidationInTrainingParams(
        meta_dataloader=test,
        interval=1,
        transform=None,
    )
).run()

Vision Transformer

In recent years, vision transformers 20 have achieved state-of-the-art results in many computer vision tasks, including semantic segmentation. SegFormer 9 is a model designed for efficient semantic segmentation based on vision transformers. SegFormer is available in different sizes; we only used the smallest one, named SegFormerB0. The SegFormer model consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head. In contrast to U-nets, SegFormer models have constant input and output sizes, so the model's inputs and outputs have to be interpolated to the correct size. For the localization task, the input image goes through the SegFormer encoder, and its outputs go through a SegFormer decode head. For the classification task, the pre- and post-disaster images go through the same SegFormer encoder; their outputs are then concatenated channel-wise and passed through a modified SegFormer decode head in which the number of channels of the MLP modules is doubled. Of course, the two outputs could also be merged in later layers, which would decrease the complexity of the distance function; such variants of the modified decode head can be created and tested in the future. Moreover, one could experiment with changing the SegFormer input size and model size. A rough sketch of the classification forward pass follows.
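This sketch treats the encoder and decode head as generic callables and is not the repository's implementation; it only illustrates the interpolation and channel-wise merging described above.

import torch
import torch.nn.functional as F
from torch import nn

def segformer_classification_forward(encoder: nn.Module, decode_head: nn.Module,
                                     pre_img: torch.Tensor, post_img: torch.Tensor,
                                     model_size=(512, 512)) -> torch.Tensor:
    # SegFormer-style models work at a fixed resolution, so inputs are resized first
    original_size = pre_img.shape[-2:]
    pre = F.interpolate(pre_img, size=model_size, mode="bilinear", align_corners=False)
    post = F.interpolate(post_img, size=model_size, mode="bilinear", align_corners=False)

    pre_features = encoder(pre)    # hierarchical encoder: a list of multi-scale feature maps
    post_features = encoder(post)  # same encoder, shared weights

    # channel-wise concatenation at every scale; the decode head's MLPs take twice the channels
    merged = [torch.cat([p, q], dim=1) for p, q in zip(pre_features, post_features)]
    logits = decode_head(merged)

    # resize the per-pixel logits back to the original image size
    return F.interpolate(logits, size=original_size, mode="bilinear", align_corners=False)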

Training Setup

We trained some models with multiple random seeds (multiple folds) to ensure they have low variance and consistent scores. We trained localization models only on pre-disaster images, because post-disaster images added noise to the data; we used post-disaster images only in rare cases as additional augmentation. We initialized each classification model's feature extractor with the weights of the corresponding localization model and fold number. No weights were frozen when training either the classification or the localization models. Since the dataset is unbalanced, we used weighted losses with weights proportional to the inverse of each class's sample count. We applied morphological dilation with a 5×5 kernel to the classification masks as an augmentation; the dilated masks make the predictions bolder. We also used PyTorch 7 AMP for FP16 21 training. A sketch of the mask-dilation step is shown below.
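The sketch below uses kornia's morphology module directly for illustration; the repository applies dilation through its own augmentation wrapper.

import torch
from kornia.morphology import dilation

masks = torch.randint(0, 5, size=(4, 1, 1024, 1024)).float()  # (B, 1, H, W) damage masks
kernel = torch.ones(5, 5)                                      # 5x5 structuring element
dilated = dilation(masks, kernel)                              # building regions grow slightly bolder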

Loss Functions

Example Usage
from metadamagenet.losses import WeightedSum, BinaryDiceLoss, BinaryFocalLoss

WeightedSum(
    (BinaryDiceLoss(), 1.0),
    (BinaryFocalLoss(alpha=0.7, gamma=2., reduction='mean'), 6.0)
)

Both building localization and damage classification are semantic segmentation tasks, because in both problems the model's purpose is pixel-level classification. We used a combination of multiple segmentation losses for all models. In 22, you can find a comprehensive comparison between popular loss functions for semantic segmentation.

Focal and Dice loss are the loss functions used in training the localization models. For classification models, we tried channel-wise weighted versions of Focal, Dice, and Cross-Entropy loss.

Focal Loss 23

$$ FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) $$

Focal Loss's usage is to make the model focus on hard-to-classify examples by increasing their loss value. We used it because the target distribution was highly skewed. In the building localization task, the number of pixels containing buildings was far less than the background pixels. In the damage classification task, undamaged building samples formed most of the total samples, too.

Dice Loss 24

$$ Dice\space Loss(p,t) = 1 - dice(p,t) $$

Where $dice$, $p$ and $t$ stand for dice coefficient, predictions and target values respectively.

$$ dice(A,B) = \frac{2\,|A\cap B|}{|A| + |B|} $$

The dice loss is calculated globally over each mini-batch. In the multiclass case, the loss value of each class (channel) is calculated individually, and their average is used as the final loss. Two activation functions can be applied to the model outputs before calculating the dice loss: sigmoid and softmax. Although softmax makes more sense conceptually, it makes the denominator of the final loss constant and thus has less effect on the model's training. A minimal single-channel sketch of this globally computed loss follows.
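This helper is illustrative only, not the BinaryDiceLoss implementation from the library.

import torch

def global_soft_dice_loss(probs: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Global soft dice loss for a single class/channel; probs and targets have shape (B, H, W)
    and the dice coefficient is computed over the whole mini-batch at once."""
    intersection = (probs * targets).sum()
    denominator = probs.sum() + targets.sum()
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)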

Cross-Entropy Loss 25

$$ -\sum_{c=1}^My_{o,c}\log(p_{o,c}) $$

Since we used sigmoid dice loss for multiclass damage classification, cross-entropy loss helped the model assign only one class to each pixel. On its own, it is also a good loss function for semantic segmentation tasks.

Evaluation

Example Usage
from torch import Tensor
from metadamagenet.metrics import DamageLocalizationMetric, DamageClassificationMetric

evaluator = 0.2 * DamageLocalizationMetric() + 0.8 * DamageClassificationMetric()

preds: Tensor
targets: Tensor
score = evaluator(preds, targets)

One of the most popular evaluation metrics for classifiers is the F1-score, because it accounts for precision and recall simultaneously. The macro-averaged F1-score is a good evaluation measure for imbalanced datasets. The xview2-scoring repository describes which variation of the F1-score to use for scoring this problem. We adapted their evaluation metrics but re-implemented them as metrics in the TorchMetrics 26 library, which performs better than computing the metrics in NumPy 27 and provides an easy-to-use API. The dice score is a set similarity measure that equals the F1-score.

$$ Dice(P,Q) = \frac{2\,|P \cap Q|}{|P|+|Q|} $$

$$ F1(P,Q) = \frac{2TP}{2TP + FP + FN} $$

Localization Models Scoring

The localization score is defined as a globally calculated binary F1-score. Sample-wise calculation would mean computing the score on each sample (image) and then averaging the sample scores to get the final score; in the global calculation, we instead sum the true positives, true negatives, false positives, and false negatives across all samples and compute the metric from these totals.

The localization score is a binary F1-score, which means that class zero (no building/background) is considered negative and class one (building) is considered positive. Since we only care about separating buildings from the background, a micro average is applied as well.

Classification Models Scoring

The classification score is a weighted sum of two scores: the localization score and the damage classification score. Classification models predict a label from zero to four for each pixel, indicating no building, no damage, minor damage, major damage, and destroyed, respectively. Since label values one to four indicate that a pixel belongs to a building, we calculate the localization score after converting all values above zero to one; this score measures how well the model segments buildings. We define the damage classification score as the harmonic mean of the globally computed F1-scores for classes one to four. We calculate the F1-score of each class separately and then use their harmonic mean so that each damage level receives equal importance; we prefer the harmonic mean to the arithmetic mean because the classes do not have equal support. We compute the damage classification score only on pixels whose ground-truth label is between one and four. This removes the effect of the model's localization performance from the damage classification score, so the two metrics capture the model's performance in two distinct aspects.

$$ score = 0.3 \times F1_{LOC} + 0.7 \times F1_{DC} $$

$$ F1_{DC} = \frac{4}{\frac{1}{F1_1 + \epsilon} + \frac{1}{F1_2 + \epsilon} + \frac{1}{F1_3 + \epsilon} + \frac{1}{F1_4 + \epsilon}} $$
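As a small worked example of these formulas (an illustrative helper with made-up numbers, not the library's metric implementation):

import torch

def damage_classification_score(f1_per_class: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Harmonic mean of the four per-class F1 scores, following the formula above;
    f1_per_class holds (F1_1, F1_2, F1_3, F1_4)."""
    return 4.0 / (1.0 / (f1_per_class + eps)).sum()

f1_loc = 0.85                                                       # hypothetical localization F1
f1_dc = damage_classification_score(torch.tensor([0.9, 0.6, 0.5, 0.7]))
score = 0.3 * f1_loc + 0.7 * f1_dc                                  # overall classification score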

Test-Time Augment

Test-Time Augment

While validating a model, we give each piece (or mini-batch) of data to the model and compute a score by comparing the model's output with the correct labels. Test-time augmentation is a technique for enhancing the accuracy of the predictions by eliminating a bias of the model. For each sample, we use reversible augmentations to generate multiple "transformed samples". The predicted label for the original sample is computed as the average of the predicted labels for the transformed samples. For example, we generate the transformed samples by rotating the original image by 0, 90, 180, and 270 degrees clockwise. Then, we get the model predictions for these transformed samples, rotate the predicted masks by 0, 90, 180, and 270 degrees counterclockwise, and average them; this average counts as the model's prediction for the original sample. Using this technique, we eliminate the model's bias with respect to rotation. By reversible augmentation, we mean that no information is lost while generating the transformed samples and aggregating their results; for instance, in semantic segmentation, shifting an image does not count as a reversible augmentation because it discards part of the image. However, this technique usually does not improve the performance of well-trained models much, because their bias with respect to something as simple as rotation is tiny. The same was true for our models when we used flipping and 90-degree rotations as test-time augmentation. A sketch of the idea follows.
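Conceptually, a four-rotation test-time-augment wrapper does something like the following (an illustrative sketch; the FourRotations wrapper in the usage example below is the actual API):

import torch
from torch import nn

def four_rotation_tta(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Average the model's predictions over four rotations of the input."""
    predictions = []
    for k in range(4):
        rotated = torch.rot90(images, k, dims=(-2, -1))           # rotate the input by k*90 degrees
        pred = model(rotated)
        predictions.append(torch.rot90(pred, -k, dims=(-2, -1)))  # rotate the prediction back
    return torch.stack(predictions).mean(dim=0)                   # average the aligned predictions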

Example Usage
from metadamagenet.models import FourFlips, FourRotations, BaseModel

model: BaseModel
model_using_test_time_augment = FourRotations(model)

Results

Using pre-trained feature extractors from the localization models allowed the classification models to train much faster and reach higher scores. Using dilated masks improved accuracy around building borders and helped with shifts and different nadirs. The model's classifier module determines each pixel's value based on a distance function between the features extracted from the pre- and post-disaster images. In U-models, the classifier module is a 2D convolution, while in SegFormer models it is a SegFormer decode head; hence, U-models learn a much simpler distance function than SegFormer models. The simplicity of this distance function helps them avoid overfitting but also prevents them from learning some sophisticated patterns. In the end, SegFormer models train much faster before overfitting on the training data, but U-models slowly reach almost the same score. EfficientUnet localization models train better without focal loss. Softmax dice loss does not perform well in training the damage classification models; a combination of per-class (per-channel) sigmoid dice loss and cross-entropy loss gives the best results when training a classification model. The effect of SCSE in the decoder modules and Wide-SE in the encoder modules of a U-net is very limited; these variations of EfficientUnets performed almost the same as the standard versions.

Complete results are available in results.md.

Discussion and Conclusion

Detecting buildings and their damage levels using artificial intelligence can improve the speed and efficiency of rescue operations after natural disasters. Solving this problem makes it possible to identify damaged areas on a large scale and to prioritize the areas most affected by a disaster during rescue operations. We tested different decoders in the U-net modules and utilized different variations of EfficientNet as the backbone of our models. Additionally, we fine-tuned SegFormer for our specific task. The result was models with fewer parameters (approximately three million) that performed much better than the previous models (damage classification score = 0.77). Because of their smaller size, these models have shorter training and inference times; therefore, they can be trained and deployed faster and fine-tuned more easily for new natural disasters. Considering damage classification and building localization in each natural disaster as a separate task, we utilized MAML and trained models that can be adapted to a new natural disaster using only a few samples. These models do not yet have satisfactory performance, but we hope to build better models of this type in the future.

Future Ideas

The number of channels in the decoder can be treated as a hyperparameter to be changed and tuned. Additionally, we can analyze the effect of the backbone size in the Efficient-Unets by trying EfficientNet-B5 or B7 as the backbone. The layer at which the embeddings of the pre- and post-disaster images are concatenated dictates the complexity of the distance function in the classifier; this effect can also be tested and analyzed. Log-Cosh Dice and Focal-Tversky are two loss functions that show the best performance in training segmentation models 22; we can also try training our models with these two loss functions, but in that case we have to modify them so that we can assign weights to the classes. The low performance of the meta-learning model may not only be due to the small number of training epochs or the small number of shots. We can try first-order MAML algorithms like Reptile 28 instead of the original MAML algorithm; these algorithms use less memory, so we could test the effects of other factors and hyperparameters faster. Previous research on meta-learning for semantic segmentation may also help us train a better model for our specific problem 29 30.

Further Reading

References

Footnotes

1. 🔗 Xview2 First Place Solution

2. 🔗 Competition and Dataset: Xview2 org.

3. 📄 xBD: A Dataset for Assessing Building Damage

4. 📄 OpenCV

5. 📄 Kornia: an Open Source Differentiable Computer Vision Library for PyTorch

6. 📄 A survey on Kornia: an Open Source Differentiable Computer Vision Library for PyTorch

7. 📄 Automatic differentiation in PyTorch

8. 📄 U-Net: Convolutional Networks for Biomedical Image Segmentation

9. 📄 SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

10. 📄 Deep Residual Learning for Image Recognition

11. 📄 Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks

12. 📄 Dual Path Networks

13. 📄 Aggregated Residual Transformations for Deep Neural Networks

14. 📄 Recalibrating Fully Convolutional Networks with Spatial and Channel 'Squeeze & Excitation' Blocks

15. 📄 EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks

16. 📄 ImageNet: A large-scale hierarchical image database

17. 📄 The Omniglot challenge: a 3-year progress report

18. 📄 Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

19. 📄 Generalized Inner Loop Meta-Learning

20. 📄 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

21. 📄 AMPT-GA: Automatic Mixed Precision Floating Point Tuning for GPU Applications

22. 📄 A survey of loss functions for semantic segmentation

23. 📄 Focal Loss for Dense Object Detection

24. 📄 Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations

25. 📄 Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels

26. 🔗 Machine learning metrics for distributed, scalable PyTorch applications.

27. 📄 Array programming with NumPy

28. 📄 On First-Order Meta-Learning Algorithms

29. 📄 Meta-seg: A survey of meta-learning for image segmentation

30. 📄 Meta-Learning Initializations for Image Segmentation