
TF implementation of a ResNet architecture on the heavily imbalanced SIIM-ISIC melanoma dataset, using LDAM loss and stratified batch normalization to guard against the imbalance problem


karurb92/ldam_str_bn


ldam_str_bn

Setup


The setup below works on a UNIX-like system. Windows should work in a similar fashion; the main difference is that the virtual environment is activated with <directory name>\Scripts\activate.

python3 -m venv <directory name>
source <directory name>/bin/activate
pip install -r requirements.txt

The dataset should be stored in a folder called local_work, and all images should reside in a child folder called all_imgs. These names can be adjusted in the config file. You can read more about the dataset in the corresponding section below.
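The expected layout can be checked with a small helper such as the one below. It is only a sketch: DATA_ROOT, IMG_DIR, image_path and missing_images are hypothetical names (the actual folder names come from the repository's config file), and the .jpg extension assumes the images are named after their image_id, as in the Kaggle download.

```python
from pathlib import Path

# Hypothetical defaults; the real folder names are set in the repository's config file.
DATA_ROOT = Path("local_work")
IMG_DIR = "all_imgs"


def image_path(image_id, root=DATA_ROOT, img_dir=IMG_DIR):
    """Expected location of one image, e.g. local_work/all_imgs/ISIC_0027419.jpg."""
    return Path(root) / img_dir / f"{image_id}.jpg"


def missing_images(image_ids, root=DATA_ROOT, img_dir=IMG_DIR):
    """Return the image ids whose file is not found in the expected folder."""
    return [i for i in image_ids if not image_path(i, root, img_dir).is_file()]
```

Running missing_images over all image ids from the metadata before training is a cheap way to catch a misplaced folder.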

Datasets (HAM10000)


https://www.kaggle.com/kmader/skin-cancer-mnist-ham10000

The metadata has 7 columns: lesion_id, image_id, dx, dx_type, age, sex, localization

Example row: HAM_0000118, ISIC_0027419, bkl, histo, 80.0, male, scalp
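A minimal sketch of reading this metadata with Python's standard csv module, assuming a CSV file whose header row lists the columns in the order above. parse_metadata is a hypothetical helper, not part of this repository.

```python
import csv
import io

# Column order of the metadata file, as listed above.
COLUMNS = ["lesion_id", "image_id", "dx", "dx_type", "age", "sex", "localization"]


def parse_metadata(text):
    """Parse metadata CSV text (including its header row) into a list of dicts."""
    reader = csv.DictReader(io.StringIO(text))
    if reader.fieldnames != COLUMNS:
        raise ValueError(f"unexpected header: {reader.fieldnames}")
    rows = []
    for row in reader:
        # Convert age to float; an empty cell becomes None.
        row["age"] = float(row["age"]) if row["age"] else None
        rows.append(row)
    return rows


sample = (
    "lesion_id,image_id,dx,dx_type,age,sex,localization\n"
    "HAM_0000118,ISIC_0027419,bkl,histo,80.0,male,scalp\n"
)
print(parse_metadata(sample)[0]["dx"])  # bkl
```

For a file on disk, pass the contents of open(path, newline="").read() instead of the inline sample.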

Topic & Tasks


To deal with a heavily imbalanced dataset, we focused on two approaches: the label-distribution-aware margin (LDAM) loss and stratified batch normalization.

  • Label-distribution-aware margin (LDAM) loss
  • Stratified Batch Normalization
    • The first layer of the net is normalized separately for each stratification class. For example, if sex and age_mapped are the dimensions used for stratification, there are 6 stratification classes (the Cartesian product of (male, female, unknown) and (<=50, >50)).
    • Each stratification class uses its own set of gammas and betas.
    • The underlying assumption is that label distributions differ significantly between stratification classes. Therefore, inputs from different classes are normalized separately, to even them out before they are passed further through the network.
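The idea above can be sketched without any framework, assuming sex and a 50-year age threshold as the stratification dimensions. STRAT_CLASSES, strat_class_index and StratifiedBatchNorm are hypothetical names, not the code in models.strat_bn_simplified, and the momentum/epsilon defaults are borrowed from Keras' BatchNormalization. In training mode each class is normalized with its own batch statistics and updates its own moving mean and variance; in inference mode the stored per-class statistics are used instead.

```python
import itertools
import math

# Hypothetical stratification dimensions, matching the example above.
SEXES = ["male", "female", "unknown"]
AGE_BUCKETS = ["<=50", ">50"]
# Cartesian product of the dimensions: 3 x 2 = 6 stratification classes.
STRAT_CLASSES = list(itertools.product(SEXES, AGE_BUCKETS))


def strat_class_index(sex, age):
    """Map one sample's metadata to a stratification class id in 0..5.

    Assumes age is known; handling of missing ages is not covered here.
    """
    sex = sex if sex in SEXES else "unknown"
    bucket = "<=50" if age <= 50 else ">50"
    return STRAT_CLASSES.index((sex, bucket))


class StratifiedBatchNorm:
    """Batch norm with separate statistics and parameters per stratification class.

    x is a batch of feature vectors (lists of floats); classes holds one class id per sample.
    """

    def __init__(self, num_classes, num_features, momentum=0.99, eps=1e-3):
        self.momentum, self.eps = momentum, eps

        def per_class(value):
            return [[value] * num_features for _ in range(num_classes)]

        self.gamma, self.beta = per_class(1.0), per_class(0.0)
        self.moving_mean, self.moving_var = per_class(0.0), per_class(1.0)

    def __call__(self, x, classes, training):
        nf = len(x[0])
        out = [None] * len(x)
        for c in set(classes):
            idx = [i for i, k in enumerate(classes) if k == c]
            if training:
                # Batch statistics computed only over the samples of class c.
                mean = [sum(x[i][f] for i in idx) / len(idx) for f in range(nf)]
                var = [sum((x[i][f] - mean[f]) ** 2 for i in idx) / len(idx) for f in range(nf)]
                m = self.momentum
                self.moving_mean[c] = [m * a + (1 - m) * b for a, b in zip(self.moving_mean[c], mean)]
                self.moving_var[c] = [m * a + (1 - m) * b for a, b in zip(self.moving_var[c], var)]
            else:
                # Inference: use the accumulated moving statistics of class c.
                mean, var = self.moving_mean[c], self.moving_var[c]
            for i in idx:
                # Normalize, then apply the gamma/beta that belong to class c.
                out[i] = [self.gamma[c][f] * (x[i][f] - mean[f]) / math.sqrt(var[f] + self.eps)
                          + self.beta[c][f] for f in range(nf)]
        return out
```

In use, each sample's class id comes from strat_class_index, and the same layer is called with training=True during training and training=False at inference.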

We artificially made the medical imaging dataset highly imbalanced (with different imbalance ratios); strat_data_generator and utils_sc.draw_data() implement this. We then implemented stratified batch normalization (models.strat_bn_simplified) within a ResNet model (models.resnet), trained with the LDAM loss (losses). Finally, we wrote unit tests with Python's unittest module to check that the loss function, stratified batch normalization and data generator work correctly.
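The repository's utils_sc.draw_data() is not reproduced here. The sketch below shows one common way to create an artificial imbalance, assuming an exponential profile in which class sizes decay from the largest class down to largest / imbalance_ratio. class_sizes and draw_imbalanced are hypothetical names, and the sampling scheme in draw_data() may differ.

```python
import random
from collections import defaultdict


def class_sizes(n_max, num_classes, imbalance_ratio):
    """Class counts decaying exponentially from n_max down to n_max / imbalance_ratio."""
    if num_classes == 1:
        return [n_max]
    return [max(1, round(n_max * imbalance_ratio ** (-i / (num_classes - 1))))
            for i in range(num_classes)]


def draw_imbalanced(samples, labels, imbalance_ratio, seed=0):
    """Subsample (sample, label) pairs so that class counts follow class_sizes().

    Classes are ranked by their original frequency, so the largest class stays largest.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for s, y in zip(samples, labels):
        by_label[y].append(s)
    order = sorted(by_label, key=lambda y: len(by_label[y]), reverse=True)
    sizes = class_sizes(len(by_label[order[0]]), len(order), imbalance_ratio)
    drawn = []
    for y, k in zip(order, sizes):
        k = min(k, len(by_label[y]))  # cannot draw more samples than the class has
        drawn.extend((s, y) for s in rng.sample(by_label[y], k))
    rng.shuffle(drawn)
    return drawn
```

Varying imbalance_ratio (for example 10, 50, 100) produces the different imbalance levels used for comparison; the fixed seed keeps each draw reproducible.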

Challenges


  1. Finding a suitable network architecture
  2. Deciding which dimensions to stratify on: choice of features and the data transformations this requires.
  3. Building our own data generator and feeding metadata to the net in a customized way.
  4. Implementing stratified batch normalization
    • Understanding the concept and the original TensorFlow BN implementation
    • Handling parameters with new shapes in both training and inference modes (i.e. updating and using moving_mean, moving_variance, beta and gamma)
  5. Converting the LDAM loss function from PyTorch to TensorFlow
    • Understanding the concept of LDAM in general
    • Dealing with the different data structures and methods of the two frameworks
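Independently of either framework, the LDAM loss itself is small. The sketch below follows the formulation of the original LDAM paper: each class j gets a margin proportional to n_j^(-1/4), where n_j is its training count; the margin is subtracted from the true-class logit, and the adjusted logits are scaled by s before a softmax cross-entropy. The defaults max_m=0.5 and s=30 are assumed to match those commonly used with the authors' PyTorch implementation, and ldam_margins/ldam_loss are hypothetical names, not the functions in losses.

```python
import math


def ldam_margins(class_counts, max_m=0.5):
    """Per-class margins proportional to n_j ** -0.25, rescaled so the largest is max_m."""
    raw = [1.0 / n ** 0.25 for n in class_counts]
    top = max(raw)
    return [max_m * r / top for r in raw]


def ldam_loss(logits, targets, class_counts, max_m=0.5, s=30.0):
    """Mean LDAM loss over a batch.

    logits: list of per-sample lists of raw class scores
    targets: list of integer class labels
    class_counts: number of training samples per class
    """
    margins = ldam_margins(class_counts, max_m)
    total = 0.0
    for z, y in zip(logits, targets):
        # Subtract the margin only from the true-class logit, then scale everything by s.
        adj = [s * (v - margins[y]) if j == y else s * v for j, v in enumerate(z)]
        # Numerically stable cross-entropy: logsumexp(adj) - adj[y].
        top = max(adj)
        lse = top + math.log(sum(math.exp(a - top) for a in adj))
        total += lse - adj[y]
    return total / len(targets)
```

Because the margin grows as n_j shrinks, minority classes are pushed towards a larger decision margin, which is the intended counterweight to the imbalance; with max_m=0 the function reduces to an ordinary scaled cross-entropy, a convenient check when comparing a port against the original.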

Team's contribution


  1. Data preprocessing: implemented our own data generator (strat_data_generator) and utilities (utils_sc)

  2. Implemented the LDAM loss in TensorFlow (losses)

  3. Implemented stratified batch normalization within a ResNet model (models.strat_bn_simplified, models.resnet)

  4. Unit tests with unittest:

    • LDAM loss: compare the outputs of the PyTorch and TensorFlow LDAM losses element by element
    • Stratified batch normalization: compare the outputs for two images from the same or from different stratification classes
    • Data generator: check that it yields the metadata (about stratification classes) correctly
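The stratified batch normalization test idea can be illustrated with a self-contained unittest sketch. strat_normalize is a hypothetical one-feature stand-in (no gamma/beta, no moving statistics), not models.strat_bn_simplified; the tests mirror the idea of comparing images from the same and from different stratification classes.

```python
import math
import unittest


def strat_normalize(values, classes, eps=1e-3):
    """Hypothetical 1-D stratified normalization: each value is standardized with the
    mean and variance of the batch entries that share its stratification class."""
    out = []
    for v, c in zip(values, classes):
        group = [w for w, k in zip(values, classes) if k == c]
        mean = sum(group) / len(group)
        var = sum((w - mean) ** 2 for w in group) / len(group)
        out.append((v - mean) / math.sqrt(var + eps))
    return out


class StratNormTest(unittest.TestCase):
    def test_other_classes_do_not_leak(self):
        base = strat_normalize([1.0, 3.0, 10.0, 20.0], [0, 0, 1, 1])
        changed = strat_normalize([1.0, 3.0, 500.0, -7.0], [0, 0, 1, 1])
        # Outputs for class 0 must not depend on the class-1 samples.
        self.assertEqual(base[:2], changed[:2])

    def test_same_image_same_class(self):
        out = strat_normalize([2.0, 2.0, 0.0], [0, 0, 0])
        # Identical inputs in the same class are normalized identically.
        self.assertEqual(out[0], out[1])

    def test_same_image_different_classes(self):
        out = strat_normalize([2.0, 2.0, 4.0, 0.0], [0, 1, 1, 0])
        # The same input value ends up different under different class statistics.
        self.assertNotAlmostEqual(out[0], out[1])
```

Saved as a module, it runs with python -m unittest <module name>.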

Results


  • Stratified Batch Normalization (without LDAM loss): [figures: epoch accuracy, epoch losses, beta, gamma, moving_mean, moving_variance]

  • LDAM Loss: [figures: epoch accuracy, epoch losses]


Team ‘Weißwürstchen’

Seunghee Jeong seunghee6022@gmail.com

Nick Stracke nick.stracke@web.de

Karol Urbańczyk karurb92@gmail.com

