[MRG] Fix batch issue when generating features + add sample_weight in deep models #220

Merged: 31 commits merged on Sep 3, 2024

Conversation

YanisLalou
Collaborator

Previously, calling model.predict_features(X) passed the entire input through the network as a single batch.
This caused CUDA out-of-memory errors on large datasets.
This PR mimics the behaviour of skorch models and processes the input in batches instead.

This might not be the best way to fix the issue, though. I would love your opinion, @tgnassou.

@YanisLalou YanisLalou changed the title Fix batch issue when generating features in deep models [WIP] Fix batch issue when generating features in deep models Jul 19, 2024
Collaborator

@tgnassou tgnassou left a comment

I'm wondering whether you could just create a DataLoader yourself, without going through the skorch functions. That way you build a DataLoader from X only, iterate over it, and that's it. The current approach seems like a lot of code for something simple, but maybe I'm wrong.
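The suggestion above boils down to a simple pattern: slice X into batches, run the feature extractor on one batch at a time, and concatenate the results. The following is a minimal, framework-free sketch of that pattern, not the code from this PR. The names `iter_batches`, `predict_features_batched`, and `toy_feature_fn` are hypothetical; in the PyTorch setting, the batch iterator would be a DataLoader over X and the feature function would be the network's feature layers evaluated without gradient tracking.

```python
from typing import Callable, Iterator, List, Sequence

Row = List[float]


def iter_batches(X: Sequence[Row], batch_size: int) -> Iterator[Sequence[Row]]:
    """Yield consecutive slices of X holding at most batch_size rows each."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size]


def predict_features_batched(
    X: Sequence[Row],
    feature_fn: Callable[[Sequence[Row]], List[Row]],
    batch_size: int = 2,
) -> List[Row]:
    """Apply feature_fn batch by batch and concatenate the results in order."""
    features: List[Row] = []
    for batch in iter_batches(X, batch_size):
        # Only one batch is handed to feature_fn at a time, so the memory
        # needed inside feature_fn depends on batch_size, not on len(X).
        features.extend(feature_fn(batch))
    return features


def toy_feature_fn(batch: Sequence[Row]) -> List[Row]:
    """Hypothetical stand-in for a network's feature extractor."""
    return [[sum(row), max(row)] for row in batch]


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
batched = predict_features_batched(X, toy_feature_fn, batch_size=2)
single = toy_feature_fn(X)  # the "everything in one batch" behaviour
```

As long as the feature function treats rows independently, `batched` and `single` hold the same values; only the peak memory per call differs.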

@antoinecollas antoinecollas changed the title [WIP] Fix batch issue when generating features in deep models [WIP] Fix batch issue when generating features + add sample_weight in deep models Aug 11, 2024
Collaborator

@antoinecollas antoinecollas left a comment

sample_weight is passed around everywhere, but in practice we only use it to reweight the loss, right? IMO it should be used only when calling the loss, but maybe I misunderstood something.
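To make "reweighting the loss" concrete, here is a small plain-Python sketch of a cross-entropy loss in which sample_weight appears only at the point where per-sample losses are aggregated. The function name `weighted_cross_entropy` is hypothetical, and normalising by the sum of the weights is an assumption made for this illustration; the PR may normalise differently.

```python
import math
from typing import List, Optional, Sequence


def weighted_cross_entropy(
    probs: Sequence[Sequence[float]],
    targets: Sequence[int],
    sample_weight: Optional[Sequence[float]] = None,
) -> float:
    """Weighted mean of per-sample negative log-likelihoods."""
    n = len(targets)
    if sample_weight is None:
        # No weights given: every sample counts equally.
        sample_weight = [1.0] * n
    if len(probs) != n or len(sample_weight) != n:
        raise ValueError("probs, targets and sample_weight must have equal length")
    total_weight = sum(sample_weight)
    if total_weight <= 0:
        raise ValueError("sample_weight must sum to a positive value")
    # Per-sample loss is computed without looking at the weights.
    per_sample: List[float] = [-math.log(p[t]) for p, t in zip(probs, targets)]
    # The weights enter only here, when the per-sample losses are combined.
    weighted_sum = sum(w * loss for w, loss in zip(sample_weight, per_sample))
    return weighted_sum / total_weight


probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
targets = [0, 1, 1]
unweighted = weighted_cross_entropy(probs, targets)
uniform = weighted_cross_entropy(probs, targets, [2.0, 2.0, 2.0])
drop_last = weighted_cross_entropy(probs, targets, [1.0, 1.0, 0.0])
```

With this normalisation, uniform weights reproduce the unweighted mean, and a weight of zero removes a sample from the loss entirely.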

If you run `coverage run -m pytest -v -s && coverage html && open htmlcov/index.html` on your branch and on the main branch, you will see that this PR adds 24 lines in skada/deep/base.py that are not covered by tests.

@antoinecollas
Collaborator

From the skorch FAQ: when X is a dict, its keys are passed as keyword arguments to forward. Our forward therefore has to accept the arguments 'data' and 'sample_weight'; usually, sample_weight can simply be ignored there.
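The consequence for the forward signature can be illustrated without skorch. The sketch below only mimics the dispatch described above with ordinary keyword unpacking; `call_forward`, `ToyModule`, and `StrictModule` are hypothetical names, and this is not skorch's actual implementation.

```python
from typing import Any, Dict, List, Optional, Union


class ToyModule:
    """Forward accepts both keys of the dict, as described above."""

    def forward(
        self, data: List[float], sample_weight: Optional[List[float]] = None
    ) -> List[float]:
        # sample_weight is accepted so the call succeeds, but it is not used.
        return [2.0 * x for x in data]


class StrictModule:
    """Forward that only knows about 'data'."""

    def forward(self, data: List[float]) -> List[float]:
        return [2.0 * x for x in data]


def call_forward(module: Any, X: Union[Dict[str, Any], List[float]]) -> List[float]:
    """Mimic the dispatch: a dict is unpacked into keyword arguments."""
    if isinstance(X, dict):
        return module.forward(**X)
    return module.forward(X)


X_dict = {"data": [1.0, 2.0, 3.0], "sample_weight": [0.5, 1.0, 1.5]}
out_dict = call_forward(ToyModule(), X_dict)
out_plain = call_forward(ToyModule(), [1.0, 2.0, 3.0])

try:
    call_forward(StrictModule(), X_dict)
    strict_failed = False
except TypeError:
    # 'sample_weight' is an unexpected keyword argument for StrictModule.
    strict_failed = True
```

A forward that lacks the sample_weight parameter fails as soon as the dict carries that key, which is why the argument has to be declared even when it is ignored.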


codecov bot commented Aug 12, 2024

Codecov Report

Attention: Patch coverage is 74.26471% with 35 lines in your changes missing coverage. Please review.

Project coverage is 96.31%. Comparing base (28aaf82) to head (6aeda7e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #220      +/-   ##
==========================================
- Coverage   97.01%   96.31%   -0.70%     
==========================================
  Files          54       54              
  Lines        5429     5486      +57     
==========================================
+ Hits         5267     5284      +17     
- Misses        162      202      +40     

@antoinecollas antoinecollas changed the title [WIP] Fix batch issue when generating features + add sample_weight in deep models [TO_REVIEW] Fix batch issue when generating features + add sample_weight in deep models Aug 13, 2024
@tgnassou
Collaborator

Very good PR. The only thing missing is to also modify CDANModule.

@tgnassou tgnassou changed the title [TO_REVIEW] Fix batch issue when generating features + add sample_weight in deep models [MRG] Fix batch issue when generating features + add sample_weight in deep models Sep 3, 2024
@tgnassou tgnassou merged commit dccc59e into scikit-adaptation:main Sep 3, 2024
5 checks passed