This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Ability to pass the TargetFormatter to use with classification data modules #1131

Closed
daMichaelB opened this issue Jan 21, 2022 · 3 comments · Fixed by #1171
Labels
enhancement New feature or request

Comments

@daMichaelB
Contributor

❓ Questions and Help

What is your question?

I have a highly imbalanced dataset in which some minority classes are very rare. I put them ONLY into the validation set. I want to validate whether the model avoids classifying them as the majority class.

The DataModule was created with:

        from flash.image import ImageClassificationData

        datamodule = ImageClassificationData.from_data_frame(
            "file", "label",
            train_images_root=...,
            val_images_root=...,
            test_images_root=...,
            train_data_frame=train_df,
            val_data_frame=valid_df,
            test_data_frame=test_df,
            ...
        )

As I understand it, I can create the ImageClassifier with the total number of classes across all splits:

        from flash.image import ImageClassifier

        model = ImageClassifier(backbone=...,
                                num_classes=self.num_classes,
                                pretrained=...)
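Because num_classes has to cover every class, one way to get a consistent value is to take the union of labels over all three splits. A minimal pure-Python sketch, assuming each split's labels are available as a list of strings (for example `list(train_df["label"])`); `collect_labels` is a hypothetical helper, not part of Flash:

```python
from typing import Iterable, List


def collect_labels(*splits: Iterable[str]) -> List[str]:
    """Hypothetical helper: sorted union of the labels found in every split."""
    labels = set()
    for split in splits:
        # Strip whitespace so " 14" and "14" count as the same class.
        labels.update(label.strip() for label in split)
    return sorted(labels)


# Toy data: class "14" only appears in the validation split.
train_labels = ["0", "1", "1", "0"]
val_labels = ["0", "14", "1"]
test_labels = ["1", "0"]

all_labels = collect_labels(train_labels, val_labels, test_labels)
num_classes = len(all_labels)
print(all_labels, num_classes)  # ['0', '1', '14'] 3
```

Sorting gives a deterministic class order, but for strings it is lexicographic, so "14" sorts before "2".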

However, my training crashes right at the start, during the validation sanity check. Traceback:

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/input.py", line 317, in __getitem__
    return self._call_load_sample(self.data[index])
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/input.py", line 236, in _call_load_sample
    return load_sample(copy(sample))
  File "/usr/local/lib/python3.8/dist-packages/flash/image/classification/input.py", line 49, in load_sample
    sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/io/classification_input.py", line 79, in format_target
    return self.target_formatter(target)
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/utilities/classification.py", line 163, in __call__
    return self.format(target)
  File "/usr/local/lib/python3.8/dist-packages/flash/core/data/utilities/classification.py", line 182, in format
    return self.label_to_idx[(target[0] if not isinstance(target, str) else target).strip()]
KeyError: '14'

I found that label 14 is present in the validation set but not in the training set.
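The last frame of the traceback looks up the stripped target string in a `label_to_idx` dictionary. If that dictionary is built only from the labels seen in training, a validation-only class has no entry. A simplified reproduction in plain Python (a sketch of the behaviour under that assumption, not Flash's actual code; `build_label_to_idx` is hypothetical):

```python
from typing import Dict, List


def build_label_to_idx(labels: List[str]) -> Dict[str, int]:
    # Hypothetical: map each distinct label to an index, in sorted order.
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}


train_labels = ["0", "1", "1", "0"]
label_to_idx = build_label_to_idx(train_labels)  # {'0': 0, '1': 1}

missing = None
try:
    label_to_idx["14".strip()]  # "14" only occurs in the validation set
except KeyError as err:
    missing = err.args[0]
    print("KeyError:", err)  # KeyError: '14'
```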

Question

Is there a way to train on a subset of the classes but validate on all classes?

What have you tried?

I have no idea how to work around this...

What's your environment?

  • OS: Ubuntu 18 (in Docker)
  • Packaging: pip
  • Version: 0.6.0
daMichaelB added the question label (Further information is requested) on Jan 21, 2022
daMichaelB changed the title from "How can one train a ImageClassificationData if Class is in Validation-Set but not in Train-Set?" to "How can one train a ImageClassifier if Class is in Validation-Set but not in Train-Set?" on Jan 21, 2022
@ethanwharris
Collaborator

Hey @daMichaelB, thanks for reporting this! We now support this internally but don't yet expose it to the user. All labels, the number of classes, etc. for classification problems are handled by a TargetFormatter object (see the API reference here: https://lightning-flash.readthedocs.io/en/latest/api/data.html#flash-core-data-utilities-classification ).

These objects are usually inferred from the training data, but in cases where that inference is not possible (e.g. where we can't efficiently get a list of all targets), we have begun to expose this object. So you could, for example, write:

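 from flash.core.data.utilities.classification import SingleLabelTargetFormatter
 from flash.image import ImageClassificationData

 datamodule = ImageClassificationData.from_data_frame(
     ...,
     target_formatter=SingleLabelTargetFormatter(labels=["label_1", ..., "label_n"]),
 )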

Would this API work for you? If so, I can get to work on adding the target_formatter argument to all of our from_* methods 😃
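The idea behind passing the formatter explicitly is that the label-to-index mapping comes from a complete list of labels rather than from the training data alone. A minimal sketch of that idea in plain Python; `ExplicitLabelFormatter` is a hypothetical stand-in that only mimics the lookup shown in the traceback, not Flash's TargetFormatter API:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ExplicitLabelFormatter:
    """Hypothetical formatter: maps single string labels to class indices."""

    labels: List[str]
    num_classes: Optional[int] = None

    def __post_init__(self) -> None:
        # The mapping covers every label passed in, including validation-only ones.
        self.label_to_idx: Dict[str, int] = {
            label: idx for idx, label in enumerate(self.labels)
        }
        if self.num_classes is None:
            self.num_classes = len(self.labels)

    def __call__(self, target: Any) -> int:
        # Accept either a bare string or a one-element sequence such as ["14"].
        label = target if isinstance(target, str) else target[0]
        return self.label_to_idx[label.strip()]


formatter = ExplicitLabelFormatter(labels=["0", "1", "14"])
print(formatter("14"), formatter(["1"]), formatter.num_classes)  # 2 1 3
```

With a mapping like this, a model created with num_classes=3 sees indices 0 to 2 in every split, whether or not a class occurs in training.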

@daMichaelB
Contributor Author

Hey @ethanwharris, this would solve a lot of trouble on my side 🎉! I think it would be a great feature for dealing with imbalanced datasets!

Thank you for the suggestion, and let me know if I can help with testing it!

ethanwharris added the enhancement label (New feature or request) and removed the question label (Further information is requested) on Jan 28, 2022
ethanwharris self-assigned this on Jan 28, 2022
ethanwharris added this to the v0.7 milestone on Jan 28, 2022
ethanwharris changed the title from "How can one train a ImageClassifier if Class is in Validation-Set but not in Train-Set?" to "Ability to pass the TargetFormatter to use with classification data modules" on Jan 28, 2022
@daMichaelB
Contributor Author

Thank you for the great support and implementation 👍
