How the backward overflow test works #650

Closed
jayz0123 opened this issue Nov 2, 2023 · 6 comments

@jayz0123

jayz0123 commented Nov 2, 2023

Dear Team,

I am currently developing FA2 for AMD platforms, and while investigating the unit tests I noticed a new test called backward overflow.

After tracing through the test script, I saw that the values of q_pt.grad and k_pt.grad are all zeros after out.backward(). Is that expected? What is the main purpose of this test?

@jayz0123
Author

jayz0123 commented Nov 2, 2023

Also, in test_flash_attn_bwd_overflow, why is q multiplied by 5 and kv by 3? What is the purpose?

@jayz0123
Author

jayz0123 commented Nov 2, 2023

I also wonder whether you are using RTN or RTZ to get all the unit tests to pass?

@tridao
Contributor

tridao commented Nov 2, 2023

These tests were added because of some bug reports about gradients being NaN. Multiplying q and kv by some amount (3 or 5) is just meant to potentially push the logits above the fp16 limit, which could lead to NaN without proper masking.

I'm not sure whether we're using RTN or RTZ, but it shouldn't matter too much. If the results are close, you can increase the tolerance a bit to make the tests pass.
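To illustrate the overflow mechanism described in this comment, here is a minimal sketch in plain Python. It is not the actual test: the logit values are made up and deliberately exaggerated, and `to_fp16` and `naive_softmax` are hypothetical helpers. It only shows how scaling the inputs by 5 and 3 (15x on the logits) can push a value past the fp16 maximum of 65504, and how an unguarded softmax then produces NaN via `inf - inf`.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to fp16 and back; overflow becomes +/-inf."""
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        # The struct format "e" is IEEE 754 half precision.
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def naive_softmax(logits):
    """Softmax with max-subtraction but no special handling of inf."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # inf - inf yields nan
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative (assumed) logits; real attention logits are usually far smaller.
base_logits = [4000.0, 5000.0, 100.0]
scale = 5 * 3  # q scaled by 5, k scaled by 3

unscaled = naive_softmax([to_fp16(v) for v in base_logits])
scaled = naive_softmax([to_fp16(v * scale) for v in base_logits])

print("unscaled:", unscaled)  # finite probabilities
print("scaled:  ", scaled)    # contains nan once a logit overflows to inf
```

A kernel that keeps the logits in fp32, or that guards the max-subtraction when the row maximum is infinite, avoids this path. The test presumably checks that the real kernels do so.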

@jayz0123
Author

jayz0123 commented Nov 7, 2023

Hi dao. Thanks for clarifying! We are doing upstream integration from your version v2.0.4. In our implementation, we need to build with RTN to run the unit tests (for higher precision) and build with RTZ for benchmarks (for better performance). I am just curious how you reach the best performance and pass all unit tests at the same time.

@tridao
Contributor

tridao commented Nov 7, 2023

Are you referring to the fp32 -> fp16/bf16 conversion of the forward pass output? Or fp32 -> fp16/bf16 conversion of dQ, dK, dV? Or is there another conversion?
Tbh I'm not sure if I'm using RTN or RTZ, whatever is the default in Cutlass when doing fp32 -> fp16/bf16 conversion.
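For readers unfamiliar with the two rounding modes under discussion, the following sketch contrasts round-to-nearest-even (RTN) with round-toward-zero (RTZ) for an fp32 -> bf16 conversion. It is an illustration in plain Python, not the Cutlass or AMD implementation. The helper names are hypothetical, and NaN/inf inputs are deliberately not handled.

```python
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x after rounding to fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Interpret a 32-bit pattern as an fp32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def to_bf16_rtz(x: float) -> float:
    """Round toward zero: drop the low 16 bits of the fp32 pattern."""
    return bits_f32(f32_bits(x) & 0xFFFF0000)


def to_bf16_rtn(x: float) -> float:
    """Round to nearest, ties to even (assumes x is finite)."""
    b = f32_bits(x)
    lsb = (b >> 16) & 1                  # lowest bit that survives in bf16
    b = (b + 0x7FFF + lsb) & 0xFFFFFFFF  # rounding bias, then truncate
    return bits_f32(b & 0xFFFF0000)


x = 1.013671875  # 1 + 1.75 * 2**-7, between two adjacent bf16 values
print(to_bf16_rtz(x))  # 1.0078125 (error 0.75 * 2**-7)
print(to_bf16_rtn(x))  # 1.015625  (error 0.25 * 2**-7)
```

RTZ is a plain truncation, which is why it can be cheaper, but it always rounds toward zero, so its error is biased and can reach a full unit in the last place. RTN errs in both directions and is at most half a unit off, which is consistent with needing a looser tolerance under RTZ.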

@tridao
Contributor

tridao commented Nov 7, 2023

I think we're using RTN, since that's what the Cutlass conversion function calls.

@jayz0123 jayz0123 closed this as completed Dec 7, 2023