[float8] FP8 GPT 1.5B Delayed Scaling 2x Slower than BF16 #1297
Hi @OrenLeung, thanks for the repro! This looks like a bug in how we handle delayed scaling + autocast, let me take a look.
Summary:

Fixes a bug with delayed scaling + autocast. Before, the last input dtype when in autocast was queried from the input to `torch._scaled_mm`:

```
x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm
```

This is incorrect because the dtype was saved from before the place where autocast could change it. This happened to work if `x_hp` was already of the correct dtype, but did not work in cases such as the new test case added in this PR, or real models such as the repro from #1297. The reason we haven't caught this for so long is that we've been using FSDP's mixed precision and not single-GPU autocast.

The fix I'm taking here is to query the original post-autocast dtype based on the output of `torch._scaled_mm`. Since this dtype is based on the dtype of the input to `torch._scaled_mm`, this will properly capture autocasting:

```
x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
```

Test Plan:

```
// first, test the updated test case - it passes
// second - test a modified version of the repro in #1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8
```
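To make the two query points concrete, here is a minimal sketch of the failure mode. The `state` dict and function names are hypothetical, and a plain matmul under CPU autocast stands in for the fp8 `torch._scaled_mm` path; this illustrates the pattern, not the actual torchao internals:

```python
import torch

state = {}

def forward_buggy(x_hp, w):
    # BUG: the dtype is recorded from the pre-autocast input, so the cast
    # that autocast performs inside the region is never observed.
    state["last_input_dtype"] = x_hp.dtype
    with torch.autocast("cpu", dtype=torch.bfloat16):
        return x_hp @ w  # autocast runs this matmul in bf16

def forward_fixed(x_hp, w):
    with torch.autocast("cpu", dtype=torch.bfloat16):
        out = x_hp @ w
    # FIX: query the dtype from the matmul output, which reflects the
    # dtype autocast actually used for the inputs.
    state["last_input_dtype"] = out.dtype
    return out

x, w = torch.randn(4, 8), torch.randn(8, 8)
forward_buggy(x, w)
print(state["last_input_dtype"])  # torch.float32 -- wrong under autocast
forward_fixed(x, w)
print(state["last_input_dtype"])  # torch.bfloat16 -- matches autocast
```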
thanks again for the report, #1306 should fix this. With that PR on my H100 machine, float8 now shows a speedup over bf16.
@vkuzo Thanks for the quick fix! I am guessing you ran your benchmark on the 500W H100 version? I can confirm the fix using #1306.
Yes, that's correct.
closing since the fix landed
Hi Torch Team,
I am currently experimenting with native torch float8 training, comparing it to Transformer Engine using the delayed scaling recipe on GPT 1.5B at batch=12, seq=1024, on a 700W H100 SXM 80G SKU.
I see that the fp8 Transformer Engine provides a slight perf improvement compared to autocast bf16, but unfortunately torchao.float8 is almost 2x slower. I attempted to improve performance by enabling fp8 and bf16 autocast at the same time, but ran into the following error:
ValueError: All layers must have the same last seen input_dtype, got {torch.float32, torch.bfloat16}
Enabling fp8 together with bf16 autocast is something that TE does, but I am not sure whether it is needed for torchao. Can you provide some guidance on how to improve performance with torchao.float8?
Thanks!
Repro Script
Dependencies
$ pip list | grep torch
pytorch-triton     3.1.0+cf34004b8a
torch              2.6.0.dev20241030+cu124
torch-tb-profiler  0.4.3
torchao            0.7.0.dev20241112+cu121
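For illustration, here is a minimal sketch of the configuration described above (torchao float8 with delayed scaling, wrapped in bf16 autocast). It is not the original repro script: the toy model, shapes, and training loop are placeholders, and the calls assume the `torchao.float8` API of the torchao version listed above.

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
)

# Toy stand-in for GPT 1.5B (hidden size 1600); only the nn.Linear
# layers are swapped to float8. Requires an fp8-capable GPU (e.g. H100).
model = nn.Sequential(
    nn.Linear(1600, 6400), nn.GELU(), nn.Linear(6400, 1600)
).cuda()

# Delayed scaling for input, weight, and grad_output casts.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
convert_to_float8_training(model, config=config)

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(12 * 1024, 1600, device="cuda")  # batch=12, seq=1024

for _ in range(10):
    # bf16 autocast around the fp8 model -- the combination that
    # triggered the input_dtype error before the fix in #1306.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = model(x).float().pow(2).mean()
    loss.backward()
    # Delayed scaling requires syncing amax/scale history every step,
    # before the optimizer update.
    sync_float8_amax_and_scale_history(model)
    opt.step()
    opt.zero_grad()
```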