Dataset balancing #350
Labels: enhancement, help wanted, won't fix
🚀 Feature
The ability to have a Flash `DataModule` sample its data in a manner that counters any imbalance in the number of labels present for each class.

Motivation
When training a classifier on a dataset where the classes are not equally represented, the resulting model is likely to develop a bias toward the more represented classes and become less capable of identifying under-represented classes. Currently, when using a Flash `DataModule` subclass for loading data, not only does Flash not provide any solution to this problem, the current implementation of the `DataModule` actually makes it harder to address the issue than plain PyTorch Lightning.

Pitch
I suggest a two-step approach to solve this.
Step 1: expose the option to pass a `torch.utils.data.Sampler` to the PyTorch `DataLoader` being used in the `DataModule` class. This paves the way for allowing custom sampling strategies in general.
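A minimal sketch of what this could look like; the `sampler` argument and the `BalancedDataModule` class are hypothetical stand-ins, not part of the current Flash API:

```python
# Hypothetical sketch of Step 1: let the DataModule accept a Sampler
# and forward it to the training DataLoader.
from typing import Optional

from torch.utils.data import DataLoader, Sampler


class BalancedDataModule:  # stand-in for a Flash DataModule subclass
    def __init__(self, train_dataset, batch_size: int = 32,
                 sampler: Optional[Sampler] = None):
        self.train_dataset = train_dataset
        self.batch_size = batch_size
        self.sampler = sampler

    def train_dataloader(self) -> DataLoader:
        # DataLoader does not allow both a sampler and shuffle=True,
        # so only shuffle when no custom sampler is provided.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=self.sampler,
            shuffle=self.sampler is None,
        )
```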
Step 2: add an option for the `DataModule` to count the number of labels for each class in the dataset, create a set of weights giving the sampling probability for each sample depending on the label it contains, and pass a `torch.utils.data.WeightedRandomSampler` with those weights to the training set dataloader.
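A sketch of the weighting logic for the single-label case, assuming the labels can be listed up front; the helper name is only illustrative:

```python
# Sketch of Step 2: weight each sample by the inverse frequency of its class
# so that all classes are drawn with roughly equal probability.
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(labels):
    """Build a WeightedRandomSampler from a list of per-sample class labels."""
    counts = Counter(labels)
    weights = torch.tensor([1.0 / counts[label] for label in labels],
                           dtype=torch.double)
    # replacement=True lets under-represented samples be drawn several times
    # per epoch; num_samples keeps the epoch length equal to the dataset size.
    return WeightedRandomSampler(weights, num_samples=len(labels),
                                 replacement=True)


# Usage with the sketch above:
# labels = [label for _, label in train_dataset]
# dm = BalancedDataModule(train_dataset, sampler=make_balanced_sampler(labels))
```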
Alternatives

Step 1 described above will, on its own, at the very least not hinder people from implementing their own sampling logic while still using the convenience of the Flash `DataModule`.
Additional context
It should be noted that the logic I suggest in step 2 above is fairly trivial in cases with one class label per sample. However, in cases with multiple labels per sample, it quickly becomes less simple. I don't think this should be an argument against supplying the option for the cases where it is useful, though.
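To illustrate why the multi-label case is harder, here is one possible (and admittedly naive) heuristic: weight each sample by the mean inverse frequency of the labels it carries. This does not guarantee a balanced label distribution and is only a sketch under the assumption that every sample has at least one label:

```python
# Naive multi-label sketch: a sample's weight is the average inverse
# frequency of its labels. Balancing one label can still skew another,
# which is exactly the complication described above.
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def make_multilabel_sampler(label_sets):
    """label_sets: one iterable of labels per sample (each non-empty)."""
    counts = Counter(label for labels in label_sets for label in labels)
    weights = torch.tensor(
        [sum(1.0 / counts[label] for label in labels) / len(labels)
         for labels in label_sets],
        dtype=torch.double,
    )
    return WeightedRandomSampler(weights, num_samples=len(label_sets),
                                 replacement=True)
```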
Note also that I have not completely covered the logic needed here. For instance, if we want to retain the concept of epochs, it can no longer mean the time when the entire dataset has been seen once, since allowing each sample to be seen exactly once will mean that the model ends up being trained with the same skewed distribution as is present in the dataset.