-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove message of checking apex.amp
module and add tests for features of gradient accumulation/mixed precision training
#46
Conversation
The original propose of that message is to let users know gradient accumulation and mixed precision training is supported but `apex` is required. With an attention brought up by issue davidtvs#45, the following things are confirmed: - Gradient accumulation can still work properly without `apex.amp`. And that's why it would fall back on normal `loss.backward()` when `apex.amp` is not available or `amp.initialize()` wasn't called. - When mixed precision training is required, that is to say model and optimizer are wrapped by `amp.initialize()`, `amp.scale_loss()` will be adopted automatically in current implementation. Therefore, it seems that message of checking `apex.amp` module is not necessary anymore.
This mistake made batch size of every data loader become the default value: 1. Though it does not affect the correctness of all test case, it still needs to be corrected. However, `batch_size` of a `DataLoader` cannot be modified after it is initialized. Therefore, we can only determine it while generating tasks for test, and that's why `batch_size` and `steps` is moved to the signature of `__init__` of each `Task`.
This functionality was not added before, and it made all tests run on CPU even if the pytest argument `--cpu_only` is not specified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for another good contribution
tests/test_lr_finder.py
Outdated
@@ -4,6 +4,14 @@ | |||
import task as mod_task | |||
|
|||
|
|||
try: | |||
import apex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this line can be changed to from apex import amp
and we can then remove the local imports from the functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I'll fix it.
tests/test_lr_finder.py
Outdated
reason="`apex` module and gpu is required to run this test." | ||
) | ||
def test_gradient_accumulation_with_apex_amp(self, mocker): | ||
from apex import amp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove line (see comment about import)
tests/test_lr_finder.py
Outdated
) | ||
class TestMixedPrecision: | ||
def test_mixed_precision(self, mocker): | ||
from apex import amp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove line (see comment about import)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Forgot to select the proper radio button
/black-check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No linting violations have been found in this PR.
/flake8-lint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lintly has detected code quality issues in this pull request.
tests/test_lr_finder.py
Outdated
@@ -4,6 +4,14 @@ | |||
import task as mod_task | |||
|
|||
|
|||
try: | |||
import apex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
F401: 'apex' imported but unused
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work. Thanks!
This PR is a fix for issue #45 with some new test cases for those features implemented in PR #9.
A quick summary for this PR:
To enable mixed precision training, please install apex...
is removed. (solved in commit 227fc53)batch_size
was not passed intoDataLoader
, so that those data loaders in test cases were working with the default valuebatch_size=1
before. Though it does not affect the test correctness, it still needs to be corrected. (solved in commit 1c549ec)task.__init__()
. Therefore, all tests were running on CPU even if the pytest argument--cpu_only
is not specified. (solved in commit c854714)Note that there is a new dependency
pytest-mock
added for new test cases.