This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Dataset balancing #350

Closed
lillekemiker opened this issue May 30, 2021 · 2 comments
Labels
enhancement New feature or request help wanted Extra attention is needed won't fix This will not be worked on

Comments

@lillekemiker
Contributor

lillekemiker commented May 30, 2021

🚀 Feature

The ability to have a Flash DataModule sample its data in a manner that counters any imbalance in the number of labels present for each class.

Motivation

When training a classifier on a dataset where the classes are not represented in equal numbers, the resulting model is likely to develop a bias toward the more represented classes and become less capable of identifying underrepresented ones. Currently, when using a Flash DataModule subclass for loading data, not only does Flash not provide any solution to this problem, the current implementation of the DataModule actually makes the issue harder to address than plain PyTorch Lightning does.

Pitch

I suggest a two-step approach to solve this.
Step 1: expose the option to pass a torch.utils.data.Sampler to the PyTorch DataLoader being used in the DataModule class. This paves the way for allowing custom sampling strategies in general.
Step 2: add an option for the DataModule to count the number of labels for each class in a dataset, create a set of weights for the sampling probability for each sample depending on the label it contains, and pass a torch.utils.data.WeightedRandomSampler with the given weights to the training set dataloader.
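The step-2 logic could be sketched roughly as follows. This is illustrative only, not Flash API: the toy `labels` list and the inverse-count weighting are assumptions about one reasonable implementation.

```python
# Sketch of step 2: count labels, derive per-sample weights, and hand a
# WeightedRandomSampler to the training DataLoader. Toy data, not Flash API.
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

labels = [0, 0, 0, 0, 0, 0, 1, 1, 2]          # imbalanced toy labels
counts = Counter(labels)                       # {0: 6, 1: 2, 2: 1}
weights = [1.0 / counts[y] for y in labels]    # rarer class => larger weight

# replacement=True lets rare samples be drawn multiple times per "epoch".
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)

dataset = TensorDataset(torch.arange(len(labels)), torch.tensor(labels))
loader = DataLoader(dataset, batch_size=3, sampler=sampler)
```

In expectation this draws each class equally often, regardless of how skewed the raw counts are.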

Alternatives

Step 1 described above will, on its own, at the very least not hinder people from implementing their own sampling logic while still using the convenience of the Flash DataModule.

Additional context

It should be noted that the logic I suggest in step 2 above is fairly trivial in cases with one class label per sample. In cases with multiple labels per sample, however, it quickly becomes less simple. I don't think this should be an argument against supplying the option for the cases where it is useful, though.
Note also that I have not completely covered the logic needed here. For instance, if we want to retain the concept of epochs, an epoch can no longer mean "the entire dataset has been seen once": allowing each sample to be seen exactly once would mean the model ends up being trained on the same skewed distribution as is present in the dataset.
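For the multi-label case, one heuristic among several (an assumption on my part, not a settled design) is to weight each sample by the inverse count of its rarest label, so samples carrying rare labels are drawn more often:

```python
# Toy multi-label samples; each sample carries a set of labels.
# This weighting scheme is just one illustrative option.
from collections import Counter

samples = [{"cat"}, {"cat"}, {"cat", "dog"}, {"dog"}, {"bird"}]

counts = Counter(label for s in samples for label in s)  # cat:3, dog:2, bird:1

# Weight a sample by the inverse count of its rarest label.
weights = [1.0 / min(counts[label] for label in s) for s in samples]
```

Other schemes (averaging inverse frequencies across a sample's labels, or optimising the weights directly) would trade off differently between labels that co-occur.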

@lillekemiker lillekemiker added enhancement New feature or request help wanted Extra attention is needed labels May 30, 2021
@ethanwharris
Collaborator

Hi @lillekemiker thanks for the feature request! Step 1 from your proposal definitely sounds like something we can do. In general it would be nice to have a way of customising the data loader creation in the datamodule. Adding sampler arguments might work but would need those arguments to be added to each from_* method, which could get a bit messy. I guess we could have them as properties instead so you could do:

datamodule.train_sampler = ...

We could also create a ClassificationDataModule that has a balance method or similar to auto generate a weighted sampler. Do you have any other ideas about how it could be done? Would you be interested in having a go at this yourself?
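The property idea could look something like the sketch below. `MyDataModule`, `train_sampler`, and `train_dataset` are hypothetical names, not actual Flash API; the only real constraint encoded here is that DataLoader forbids `shuffle=True` together with a custom sampler.

```python
# Hypothetical sketch of a datamodule exposing a train_sampler property.
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, TensorDataset


class MyDataModule:
    def __init__(self, train_dataset: Dataset) -> None:
        self.train_dataset = train_dataset
        self._train_sampler: Optional[Sampler] = None

    @property
    def train_sampler(self) -> Optional[Sampler]:
        return self._train_sampler

    @train_sampler.setter
    def train_sampler(self, sampler: Sampler) -> None:
        self._train_sampler = sampler

    def train_dataloader(self) -> DataLoader:
        # shuffle must be False whenever a custom sampler is supplied.
        return DataLoader(
            self.train_dataset,
            batch_size=2,
            sampler=self._train_sampler,
            shuffle=self._train_sampler is None,
        )


dm = MyDataModule(TensorDataset(torch.arange(10)))
dm.train_sampler = torch.utils.data.RandomSampler(dm.train_dataset)
loader = dm.train_dataloader()
```

This keeps the many `from_*` constructors untouched while still letting users inject any Sampler after construction.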

Thanks 😃

@stale

stale bot commented Aug 3, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Aug 3, 2021
@stale stale bot closed this as completed Aug 12, 2021