feat(triton): InplaceNorm + InstanceNorm #50

Closed
ClashLuke opened this issue Oct 28, 2021 · 1 comment · Fixed by #53

Comments

@ClashLuke
Contributor

I'd love to run LayerNorm in place, and ideally also add InstanceNorm (by extracting the core normalization from LayerNorm), as HomebrewNLP currently uses a slow PyTorch-level implementation with a correct backward pass.
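A minimal pure-Python sketch of what "extracting the core normalization" could look like (this is not xformers code): one shared row-normalization helper, with LayerNorm normalizing each position over its features and InstanceNorm normalizing each channel over positions. The single-sample [positions][channels] layout and the names `_normalize_rows`, `layer_norm` and `instance_norm` are hypothetical, and the affine weight/bias are left out.

```python
import math


def _normalize_rows(rows, eps=1e-5):
    """Shared core: zero-mean, unit-variance normalization of each row."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance
        rstd = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * rstd for v in row])
    return out


def _transpose(rows):
    return [list(col) for col in zip(*rows)]


def layer_norm(x, eps=1e-5):
    """x: [positions][channels]; normalize each position over its channels."""
    return _normalize_rows(x, eps)


def instance_norm(x, eps=1e-5):
    """x: [positions][channels] for one sample; normalize each channel over positions."""
    return _transpose(_normalize_rows(_transpose(x), eps))
```

The two norms differ only in which axis the shared core reduces over, which is why factoring it out makes InstanceNorm nearly free once LayerNorm exists.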

While we're at it, optionally fusing GLU and GLUv2 (gelu(f(x)) * g(x) + gelu(h(x))) with various activation functions and normalization might give another speed boost.
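The GLUv2 expression above, written out elementwise as a sketch. It assumes f(x), g(x) and h(x) are projections (for example linear layers) that have already been computed; a fused kernel would evaluate the whole expression in one pass instead of launching separate ops. The names `glu_v2` and `act` are hypothetical, and `act` is a parameter so that other activation functions can be swapped in for GELU.

```python
import math


def gelu(v):
    """Exact GELU: v * Phi(v), with Phi the standard normal CDF."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def glu_v2(f_x, g_x, h_x, act=gelu):
    """Elementwise act(f(x)) * g(x) + act(h(x)) over precomputed projections."""
    return [act(a) * b + act(c) for a, b, c in zip(f_x, g_x, h_x)]
```

For example, `glu_v2([2.0], [3.0], [1.0], act=lambda v: v)` reduces to 2 * 3 + 1.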

To add this myself, I'd need to fully understand Triton's pointers and how to access the output instead of the input in your LayerNorm implementation. Could you help me with that, or would you rather implement this yourself? Is this even in the scope of xformers?
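Why the backward pass needs the output: if LayerNorm overwrites its input, the backward pass can no longer read x. One way around this, similar to the trick used by InPlace-ABN, is to keep only rstd and recover the normalized input x_hat from the output by inverting the affine step. The single-row, pure-Python sketch below assumes every weight entry is non-zero; the function names are hypothetical, weight/bias gradients are omitted, and this is not how the existing xformers Triton kernel is written.

```python
import math


def layer_norm_inplace_fwd(x, weight, bias, eps=1e-5):
    """Overwrite the list `x` (one row) with y = x_hat * weight + bias.

    Only rstd is returned for the backward pass; the original input is lost.
    Assumes every weight entry is non-zero so that x_hat can be recovered.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as LayerNorm uses
    rstd = 1.0 / math.sqrt(var + eps)
    for i in range(n):
        x[i] = (x[i] - mean) * rstd * weight[i] + bias[i]
    return rstd


def layer_norm_inplace_bwd(y, grad_y, weight, bias, rstd):
    """Return dL/dx using only the output y (no saved copy of x)."""
    n = len(y)
    # Invert the affine step to recover the normalized input.
    x_hat = [(y[i] - bias[i]) / weight[i] for i in range(n)]
    # Gradient with respect to x_hat.
    g = [grad_y[i] * weight[i] for i in range(n)]
    mean_g = sum(g) / n
    mean_g_xhat = sum(gi * xi for gi, xi in zip(g, x_hat)) / n
    # Standard LayerNorm input gradient for one row.
    return [rstd * (g[i] - mean_g - x_hat[i] * mean_g_xhat) for i in range(n)]
```

The cost of this design is one division per element in the backward pass and the non-zero-weight requirement, in exchange for not storing a copy of the activations.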

@blefaudeux
Contributor

That's a great question! Definitely open for contributions, and this looks very reasonable; there's a good chance that Triton gives something a lot faster than PyTorch there. In terms of scope I think it's very much OK, as xformers to me is also an optimized parts zoo (with some automatic builders for them, but that's optional).

Just a couple of caveats to begin with:

  • While some Triton subparts will probably work on a GTX 1xxx (basically if you're not using the tensor cores, i.e. no dot product), it's not guaranteed, so having a PyTorch fallback is good to keep in mind.

  • We keep all our Triton code here; k_xx means kernel (the @jit code), set aside so that it does not distort the code coverage metrics. LayerNorm is there (basically from the tutorial by @ptillet). Making the norm parametrizable is a great idea: it's how the fused linear layer works (activations are here). Triton is awesome for that since it will generate a fused kernel on the fly for anything you pass, so this becomes easily extensible.

  • Who does what? A PR is most certainly welcome if you feel like it; otherwise I can put up a PR which does not change LayerNorm per se, but makes the norm a @jit function that you can pass in, as is done for the fused linear layer. This way you could easily experiment with other norms and submit them in PRs. As you prefer!

  • Mini heads-up: for now GELU is a little on the slow side with Triton (about 80% of PyTorch's speed, give or take). It could be that this improves over time, and that the combination is still worth it. In any case it would be easy enough to try.
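The first caveat can be expressed as a small dispatch rule. This is a hypothetical helper, not part of xformers; the (7, 0) threshold reflects that tensor cores arrived with Volta GPUs, while a GTX 10xx card reports compute capability 6.1. Treating "kernel uses a dot product" as the only hard blocker is an assumption taken from the comment above. In real code the capability tuple could come from `torch.cuda.get_device_capability()` and `has_triton` from `importlib.util.find_spec("triton") is not None`.

```python
def pick_norm_backend(compute_capability, has_triton, uses_dot):
    """Return "triton" or "pytorch" for a given GPU and kernel.

    compute_capability: (major, minor) tuple, e.g. (6, 1) for a GTX 1080.
    uses_dot: whether the kernel relies on a dot product (tensor cores).
    Kernels without dot products are still only *attempted* on older GPUs,
    so a PyTorch fallback should remain available either way.
    """
    if not has_triton:
        return "pytorch"
    if uses_dot and compute_capability < (7, 0):  # no tensor cores before Volta
        return "pytorch"
    return "triton"
```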
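A pure-Python analogy (not Triton code) of the parametrizable design described in the second and third points: the norm and the activation are plain functions passed to a factory, and each distinct (norm, activation) pair gets its own cached "kernel", loosely mimicking how Triton compiles and caches one specialization per set of compile-time arguments. All names are hypothetical, and RMSNorm appears only as an example of another norm one might plug in.

```python
import math
from functools import lru_cache

EPS = 1e-5


def layer_norm_stats(row):
    """LayerNorm: subtract the mean, scale by 1 / std."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return mean, 1.0 / math.sqrt(var + EPS)


def rms_norm_stats(row):
    """RMSNorm: no centering, scale by 1 / root-mean-square."""
    ms = sum(v * v for v in row) / len(row)
    return 0.0, 1.0 / math.sqrt(ms + EPS)


def relu(v):
    return max(v, 0.0)


def identity(v):
    return v


@lru_cache(maxsize=None)
def make_fused_row_kernel(norm_stats, activation):
    """Build one kernel per (norm, activation) pair and reuse it afterwards."""

    def kernel(rows):
        out = []
        for row in rows:
            center, scale = norm_stats(row)
            # Normalize and activate in one loop, with no separate normalized buffer.
            out.append([activation((v - center) * scale) for v in row])
        return out

    return kernel
```

Adding a new norm then means writing one small stats function rather than a new kernel.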

xwhan pushed a commit to xwhan/xformers that referenced this issue Feb 8, 2022

* moving local attention to sparse backend
* better handling of the causal/window sizes implications
* some cleaning up

2 participants