Add mixed precision #40
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master      #40      +/-   ##
==========================================
+ Coverage   94.10%   94.12%   +0.01%
==========================================
  Files          35       35
  Lines        2003     2043      +40
==========================================
+ Hits         1885     1923      +38
- Misses        118      120       +2
@@ -221,7 +223,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     if can_benchmark and len(model.balance) == 4:
         # Assert that words per second is within 3 standard deviations of the average
        # of six golden runs
-        assert wps > 20052.1 - (3 * 359)
+        assert wps > 27799.2 - (3 * 522.145)
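For context, the bound is the mean of the six golden runs minus three standard deviations. A minimal sketch of how such a threshold can be derived (the wps values below are placeholders, not the actual golden-run data):

```python
# Sketch: deriving the regression threshold from golden benchmark runs.
# The six wps values are hypothetical, not the actual golden-run data.
import statistics

golden_wps = [27650.3, 27810.9, 28102.4, 27455.7, 27201.0, 28574.9]
mean = statistics.mean(golden_wps)
std = statistics.stdev(golden_wps)
print(f"assert wps > {mean:.1f} - (3 * {std:.3f})")
```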
Nice speedups and memory reduction (below)!
Thanks!
Thanks for the PR. Looks good overall. I left a few minor comments.
assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1281024 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2788864 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 190724608 * 1.1
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
Nice, I'm curious: what tool did you use to get these exact numbers?
I used the values printed by the four lines of code above.
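A minimal sketch of reading those peaks out (the loop itself is illustrative; the stats key is the same one the assertions use):

```python
# Sketch: print the peak allocated bytes on each GPU after a run, so the
# values can be pasted into the memory assertions above.
import torch

for device in range(torch.cuda.device_count()):
    peak = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"]
    print(f"cuda:{device} peak allocated bytes = {peak}")
```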
         m += chunk_idx*chunk_size;
-        T* v = (T *)tl.addresses[2][tensor_loc];
+        float* v = (float *)tl.addresses[2][tensor_loc];
Naive question: why do m and v have type float?
These are momentum and velocity! Right now we require them to be floats; in the Python code, they are always instantiated with dtype=torch.float32. The next pull request will add the option for them to be fp16.
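A sketch of the pattern described (buffer names follow torch.optim.Adam; this is not the PR's exact code):

```python
# Sketch: optimizer state is created in fp32 even for fp16 parameters.
import torch

param = torch.zeros(1024, dtype=torch.float16)
state = {
    "exp_avg": torch.zeros_like(param, dtype=torch.float32),     # momentum (m)
    "exp_avg_sq": torch.zeros_like(param, dtype=torch.float32),  # velocity (v)
}
assert state["exp_avg"].dtype == torch.float32
```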
             (adamMode_t) mode,
             decay
         );
     } else { // tl_sz == 4
Maybe add an explicit comment for the setting here, similar to the "mix precision case" comment above :)
Great idea! Done!
Looks good to me.
* Test CPU offload
* Remove dead code
Before submitting
What does this PR do?
Add mixed precision training to Adam. Update benchmark.
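For readers unfamiliar with the technique: a generic sketch of the fp32-master-weight pattern that mixed precision optimizers use (illustrative only; the PR's fused CUDA kernel is not shown here):

```python
# Generic sketch of mixed precision Adam via fp32 master weights.
# Illustrative only; not this PR's implementation.
import torch

fp16_param = torch.randn(1024).half()      # model weight stored in fp16
fp32_master = fp16_param.detach().float()  # fp32 master copy of the weight
m = torch.zeros_like(fp32_master)          # momentum (m), kept in fp32
v = torch.zeros_like(fp32_master)          # velocity (v), kept in fp32

lr, beta1, beta2, eps, step = 1e-3, 0.9, 0.999, 1e-8, 1
grad = torch.randn(1024)                   # fp16 gradient, upcast to fp32

# Adam update performed entirely in fp32.
m.mul_(beta1).add_(grad, alpha=1 - beta1)
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m_hat = m / (1 - beta1 ** step)
v_hat = v / (1 - beta2 ** step)
fp32_master.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

fp16_param.copy_(fp32_master)              # copy_ downcasts back to fp16
```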
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃