-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Huggingface model wrapper update #2300
Huggingface model wrapper update #2300
Conversation
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
Codecov Report
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## dev_1.17.0 #2300 +/- ##
==============================================
- Coverage 85.60% 83.84% -1.76%
==============================================
Files 324 324
Lines 29326 29982 +656
Branches 5407 5538 +131
==============================================
+ Hits 25104 25138 +34
- Misses 2840 3448 +608
- Partials 1382 1396 +14
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @GiulioZizzo Thank you very much for your pull request! Looks good to me.
Description
Currently, the HF estimator expects a HF model that is on the cpu: if a user moves the model to the GPU before giving it to the estimator there will be a crash to to a mismatch between the tensors created for the forward hooks in the HF model wrapper and the model. This small update ensures that the correct device is used between model and hook inputs.
Fixes # TO CREATE
Type of change
Please check all relevant options.
Testing
As the CI tests are CPU limited we have tested this via updating the notebook to shift the models to GPU devices in a few places before feeding to the estimator.
Test Configuration:
Checklist