[perf] Fused linear : small FW cleanup and much better perfs #283

Merged — 1 commit merged into main on Apr 26, 2022

Conversation

blefaudeux
Contributor

What does this PR do?

Improves the performance of the fused linear layer, mostly in the forward pass, but with measurable effects all around. The BW pass could be greatly improved, there is a lot of perf left on the table, and I'll try to work a bit on that in the coming days. This would be applicable to the MLP and to the MHA projection, if we end up writing a dedicated self-attention projection kernel (the op is wx+b, the same as the fused linear without the activation).
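For reference, a minimal PyTorch sketch of the unfused op this kernel fuses (illustrative only; the shapes and the GELU activation are placeholders):

```python
import torch
import torch.nn.functional as F

# Unfused reference: the Triton kernel fuses these steps
# (matmul, bias add, activation) into a single pass over the data.
x = torch.randn(64, 1024, device="cuda", dtype=torch.float16)    # (batch, in_features)
w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)  # (out_features, in_features)
b = torch.randn(4096, device="cuda", dtype=torch.float16)

y = F.gelu(x @ w.t() + b)  # wx + b, plus an optional activation
```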

Before submitting

  • [🙃] Did you have fun?
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 26, 2022
@blefaudeux blefaudeux marked this pull request as draft April 26, 2022 05:14
@blefaudeux blefaudeux force-pushed the minor_triton_fused_linear_cleanup branch 2 times, most recently from 0904eb1 to cbc97eb on April 26, 2022 05:24
@blefaudeux
Contributor Author

There's a lot of speed left in the BW pass; with a fresh pair of eyes it's obvious. This is a very small PR which already brings some decent speed, all from the FW pass. Fist bump @dianaml0, I know you were hoping for more perf there recently. I looked into this while contemplating a dedicated MHA self-attention projection kernel, which would be very similar (without the activation, and dispatching into three buffers); see the sketch below.
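A rough sketch of what such a projection kernel would compute, assuming a standard fused QKV projection (the dimensions and names are illustrative):

```python
import torch

embed_dim = 768
x = torch.randn(64, embed_dim, device="cuda", dtype=torch.float16)

# One weight covering Q, K and V, so a single wx + b can be issued...
w_qkv = torch.randn(3 * embed_dim, embed_dim, device="cuda", dtype=torch.float16)
b_qkv = torch.randn(3 * embed_dim, device="cuda", dtype=torch.float16)

# ...same op as the fused linear, minus the activation, then split into 3 buffers.
qkv = x @ w_qkv.t() + b_qkv
q, k, v = qkv.chunk(3, dim=-1)
```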

@blefaudeux blefaudeux marked this pull request as ready for review April 26, 2022 05:27
@blefaudeux blefaudeux changed the title [perf] Fused linear : small cleanup and much better perfs [perf] Fused linear : small FW cleanup and much better perfs Apr 26, 2022
@blefaudeux
Contributor Author

In a nutshell:

[screenshot: gnome-shell-screenshot-b3jc2b]

@blefaudeux
Contributor Author

blefaudeux commented Apr 26, 2022

Inference is even nicer (this PR only touches that part): thanks to the fusion, adding the bias is completely free.

[screenshot: gnome-shell-screenshot-f0a79d]
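A rough PyTorch-level analogue of that bias fusion, folding the bias add into the matmul call instead of running it as a separate elementwise kernel (illustrative only, not the Triton implementation itself):

```python
import torch

x = torch.randn(64, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(4096, device="cuda", dtype=torch.float16)

y_two_steps = x @ w.t() + b          # matmul, then a separate bias-add kernel
y_fused = torch.addmm(b, x, w.t())   # bias folded into the GEMM call
assert torch.allclose(y_two_steps, y_fused, rtol=1e-2, atol=1e-2)
```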

@blefaudeux blefaudeux force-pushed the minor_triton_fused_linear_cleanup branch from cbc97eb to 3f168e1 on April 26, 2022 06:00
@blefaudeux
Contributor Author

Ideally, @dianaml0, the backward pass should be improved, and FusedMLP could revert to using this instead of the not-super-impactful fused dropout.

@codecov-commenter

Codecov Report

Merging #283 (3f168e1) into main (a0fb375) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##             main     #283   +/-   ##
=======================================
  Coverage   92.72%   92.72%           
=======================================
  Files          61       61           
  Lines        3407     3407           
=======================================
  Hits         3159     3159           
  Misses        248      248           
Flag      Coverage Δ
Python    92.72% <ø> (ø)

Flags with carried forward coverage won't be shown.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update a0fb375...3f168e1.

Contributor

@dianaml0 dianaml0 left a comment


Awesome results!

@blefaudeux blefaudeux merged commit ac94252 into main Apr 26, 2022
@blefaudeux blefaudeux deleted the minor_triton_fused_linear_cleanup branch April 26, 2022 16:04
@moinnadeem

moinnadeem commented May 5, 2022

@blefaudeux Hey Ben! I'm curious: did you run this on an end-to-end benchmark? I've been able to recreate the increased FLOPs in a timing measurement, but it performs roughly the same as a vanilla PyTorch linear + GELU layer on BERT-Base (110M) on an A100 w/ 80 GB VRAM.
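A minimal sketch of the kind of harness used for such a timing measurement (the shapes are a guess at BERT-Base MLP sizes, and the fused layer is left as a comment since its import path depends on the installed xformers version):

```python
import torch

def bench(fn, warmup=10, iters=100):
    # Mean milliseconds per call, measured with CUDA events.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

B, D_in, D_out = 4096, 768, 3072  # roughly BERT-Base MLP dimensions
x = torch.randn(B, D_in, device="cuda", dtype=torch.float16)

baseline = torch.nn.Sequential(torch.nn.Linear(D_in, D_out), torch.nn.GELU()).to("cuda").half()
print(f"linear + GELU: {bench(lambda: baseline(x)):.3f} ms")
# Swap in the xformers Triton fused linear here for the comparison
# (exact class/import path depends on the xformers version installed).
```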

Is this expected? I did install xformers from source, since the pip package doesn't reflect this PR, and used Triton v2 instead of v1.1.

As always, great work! I expect to be a heavy xformers user :)

@blefaudeux
Contributor Author


Hey @moinnadeem, thanks for the check and the message! I've been meaning to follow up on this for some time (this branch and a few others), but could not find the time. There are a few caveats, some of which I discovered after this PR; I should probably update the curves:

  • One small caveat is that this was measured on a much smaller GPU (a 3080 laptop), where (a) the bottlenecks (compute/bandwidth) and (b) the best kernel schedules (all these Triton hyperparameters) are not the same. I don't currently have access to an A100, unfortunately. In general the benchmark scripts are also in the repo so that everyone can check the perf for themselves, which you did, and that's perfect.

  • One weird caveat is that while working on the follow-up branches, I could not reproduce these curves on the same machine; we investigated this a bit with @ptillet, the original Triton author. I'm still not sure what happened: I was working with the dev pip packages for Triton, so maybe I got a very fast (but probably wrong, and fixed down the line) package. I did try quite a few combinations and the PyTorch/CUDA perf looked normal, so it's still a mystery to me.

  • One general caveat is that while I still think we should be able to improve these layers, which would be impactful since the MLP is so prevalent, right now the BW pass is not well fused and that is what kills the perf. We don't use this layer in FusedMLP because of that. I think these layers still have value since they natively handle some activations which are very slow in PyTorch (like squared ReLU, see the sketch after this list), but other than that I would not invest too much in them on the user side right now; they are just not good/impactful enough.

  • We have a small E2E benchmark as part of the CI (see); it depends on hardware and dimensions, but in general xformers should be expected to be marginally faster and to consume a lot less RAM than the matching PyTorch (plus expose a lot more options of course, which is a big part of the point).
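As an aside, the squared ReLU mentioned above is trivial to express but runs as separate elementwise kernels in eager PyTorch, which is part of why fusing it into the linear kernel helps; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # relu followed by a square: two elementwise passes in eager PyTorch,
    # whereas the fused linear can apply it inside the matmul epilogue.
    return F.relu(x) ** 2

x = torch.randn(64, 1024, device="cuda", dtype=torch.float16)
y = squared_relu(x)
```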

@blefaudeux
Contributor Author

#296 added to make sure that I don't forget to update the curves. Ideally I would love to check in a small perf bump at the same time.

@blefaudeux
Contributor Author

I'm also beginning to wonder whether this is related to pytorch/pytorch#76509

@ngimel

ngimel commented May 5, 2022

#76509 landed yesterday, so it couldn't affect earlier results. Also, it affects only fp32 matmul ops, not fp16.

@blefaudeux
Contributor Author


Fair enough. I was wondering whether in some of the curves I was comparing Triton (tf32 internally) vs. PyTorch (fp32 at some point, then tf32, now back to fp32). I meant to reference the issue to say that I could have tripped over one of these changes (I was not always tracking PyTorch dev, although our CI is).
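For anyone checking this on their side, the relevant switch is PyTorch's tf32 matmul flag (a standard PyTorch setting; its default is what changed around that time):

```python
import torch

# When True, fp32 matmuls use tf32 tensor cores (what Triton uses internally);
# when False, they run in full fp32. fp16 inputs are unaffected either way.
torch.backends.cuda.matmul.allow_tf32 = True   # set to False to compare against strict fp32
print(torch.backends.cuda.matmul.allow_tf32)
```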

@ngimel

ngimel commented May 6, 2022

PyTorch always defaulted to tf32 until yesterday, when it became fp32. But for fp16 inputs it doesn't matter at all, and I see only fp16 curves.
