[backend] 3/3 Triton 2 update #272

Merged
merged 5 commits into from
Apr 21, 2022

Conversation


@blefaudeux blefaudeux commented Apr 13, 2022

What does this PR do?

Pushes things a little further with Triton 2. My thinking was to try to land all 3 PRs in one go once this last one is green.
Happy to update this PR on that front (and others); this is something quickly put together to try to unblock some other PRs (like #263).

TODOs:

  • Upgrade the CI
  • Fix the kernel syntax to triton 2
  • Fix the timm benchmark / points to layernorm being broken
  • Sparse softmax / bmm unit test failure -> triton2 changed some of the internal computation formats (and pytorch did the same, with a default switch to tf32 for some matmuls I believe). @fmassa the sparse_tensor unit tests were not passing even if almost nothing in the codebase had changed (except for newer CUDA + Triton2), so this PR relaxed some of the parity constraints. Please have a look, we can change things
  • Some speed improvements given triton2 changes

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 13, 2022
@blefaudeux blefaudeux requested review from fmassa and dianaml0 and removed request for fmassa April 13, 2022 05:07
@blefaudeux blefaudeux changed the base branch from main to label_attention_properties April 13, 2022 05:07
@blefaudeux blefaudeux requested a review from fmassa April 13, 2022 05:16
@blefaudeux
Contributor Author

@dianaml0 just FYI, investigating the layernorm crash, this works fine with cuda 11.6 / Ampere actually, but I can repro with 11.4

@blefaudeux
Contributor Author

blefaudeux commented Apr 15, 2022

> @dianaml0 just FYI, investigating the layernorm crash, this works fine with cuda 11.6 / Ampere actually, but I can repro with 11.4

some writes were not masked, which could be why. I just pushed a small update, works on my machine (tm)
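
To illustrate why an unmasked write can misbehave: Triton kernels typically process a row with a power-of-two block that can be wider than the row itself, so the lanes past the end of the row need a mask on both loads and stores. The sketch below is a plain-Python emulation of that pattern, not the layernorm kernel from this PR; `next_power_of_two`, `row_scale_masked` and the list-based "memory" are hypothetical stand-ins.

```python
from typing import List


def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def row_scale_masked(row: List[float], scale: float) -> List[float]:
    """Emulate one program instance handling a row with a block wider than the row."""
    n = len(row)
    block = next_power_of_two(n)        # e.g. n=5 -> block=8
    out = [0.0] * n                     # output buffer holds exactly n elements

    for lane in range(block):           # each lane stands for one element of the block
        mask = lane < n                 # lanes >= n fall outside the row
        x = row[lane] if mask else 0.0  # masked load, 0.0 plays the role of the fill value
        y = x * scale
        if mask:                        # masked store: skip the out-of-bounds lanes
            out[lane] = y
    return out


print(row_scale_masked([1.0, 2.0, 3.0, 4.0, 5.0], 2.0))
```

Dropping the `if mask:` guard raises an `IndexError` in this emulation. On a GPU the equivalent store lands in memory the kernel does not own, which may or may not crash depending on the setup, and that would be consistent with a failure that reproduces on one CUDA version but not another.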

@blefaudeux
Contributor Author

> @dianaml0 just FYI, investigating the layernorm crash, this works fine with cuda 11.6 / Ampere actually, but I can repro with 11.4

layernorm fixed, as far as I can see, should be a bit faster for very big sizes also

@blefaudeux
Contributor Author

@fmassa if you have some cycles at some point, would you mind having a look at the test_sparse_softmax test which does not pass with this branch ? You have a little more context around there (and you could check the changes from this PR, probably a bit rough in this area)

@blefaudeux blefaudeux force-pushed the triton-2 branch 3 times, most recently from 6c97cca to 4327d39 Compare April 17, 2022 15:55
@blefaudeux
Contributor Author

blefaudeux commented Apr 17, 2022

> @fmassa if you have some cycles at some point, would you mind having a look at the test_sparse_softmax test which does not pass with this branch ? You have a little more context around there (and you could check the changes from this PR, probably a bit rough in this area)

I just improved a bit on the changes in /sparse_tensor, should be ok by now, except for the fact that the new blocksparse does not accept a per-pixel mask anymore, so this probably breaks some of this abstraction.

@blefaudeux blefaudeux marked this pull request as draft April 18, 2022 04:32
@blefaudeux blefaudeux force-pushed the label_attention_properties branch from be72b26 to 8113277 Compare April 18, 2022 04:34
@blefaudeux blefaudeux force-pushed the triton-2 branch 4 times, most recently from 6369798 to 9554e3f Compare April 19, 2022 00:23
return torch.float32, 1e-1

# Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
torch.backends.cuda.matmul.allow_tf32 = False
Contributor Author

@fmassa this fixed issues that I was seeing with these unit tests on an ampere GPU, which I presume stemmed from the fact that the sparse kernels were fp32 while pytorch defaulted to tf32

Contributor

oh wow, thanks for spotting this!

One more instance where tf32 is being somewhat harmful. Maybe worth commenting on pytorch/pytorch#67384 ?

Contributor Author

It's a strange format, range of fp32 but precision of fp16. It's also kind of peculiar that it's really 19 bits but named tf32..
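
As a rough illustration of what that format does to a value: TF32 keeps 1 sign bit, the 8 exponent bits of fp32 and 10 mantissa bits (the same mantissa width as fp16), so 19 bits in total. The sketch below emulates this by zeroing the low 13 mantissa bits of an fp32 number. `to_tf32` is a hypothetical helper, and truncation is an assumption made for simplicity; actual hardware may round differently.

```python
import struct


def to_tf32(x: float) -> float:
    """Emulate TF32 by truncating the fp32 mantissa from 23 bits to 10 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    bits &= 0xFFFFE000                                   # clear the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# A small perturbation around 1.0 is lost entirely after truncation
print(to_tf32(1.0001))

# The relative error stays below 2**-10, i.e. roughly 3 significant decimal digits
x = 3.14159
print(abs(to_tf32(x) - x) / x)
```

With only about three significant digits per operand, a parity test that compares a TF32 matmul against an fp32 reference can plausibly fail at tolerances that were fine when both sides ran in fp32.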

# Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
# TF32 has the precision of fp16 but the range of fp32
# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
Contributor Author

@blefaudeux blefaudeux Apr 19, 2022

@fmassa this seems to be a better fit following the switch to triton2, which internally moved all tl.dot() operations to tf32

Contributor Author

cc @ptillet, just swapping triton 1.1 for 2.dev meant that this test would not pass anymore, as we discussed

Contributor

SGTM wrt the tests!



def _get_dtype_atol(tensor_type, device: str):
_seed()
Contributor Author

this was to remove some reproducibility issues in between circleci and my machine..

MODE,
trans_a=TRANS_A,
trans_b=TRANS_B,
device=torch.device("cuda"),
Contributor Author

triton blocksparse op now requires the device to be passed in


# triton result
- op = blocksparse_softmax(layout, BLOCK)
+ op = blocksparse_softmax(layout, BLOCK, device=torch.device("cuda"))
Contributor Author

triton blocksparse softmax now requires the device to be passed in

ty = op(
tx,
scale=scale,
key_padding_mask=kp_mask,
Contributor Author

triton blocksparse now does not support attn mask or key_padding mask. We can pass a "causal" flag though

@@ -50,7 +50,7 @@ def test_layernorm_parity(shape, amp):
torch.random.manual_seed(0)
X_ = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)

- eps = 1e-5
+ eps = 1e-4
Contributor Author

1/1e-5 overflows in fp16..
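
The arithmetic behind that remark can be checked without a GPU: the largest finite fp16 value is 65504, so 1/1e-5 (about 1e5) cannot be represented, while 1/1e-4 (about 1e4) can. The sketch below uses the half-precision `e` format of Python's `struct` module; `fits_in_fp16` is a hypothetical helper, and whether the kernel really forms `1/eps` in fp16 is not shown in this thread.

```python
import struct

FP16_MAX = 65504.0  # largest finite half-precision value


def fits_in_fp16(x: float) -> bool:
    """Return True if x can be packed as an IEEE 754 half-precision float."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


print(fits_in_fp16(1 / 1e-5))  # about 1e5, above FP16_MAX
print(fits_in_fp16(1 / 1e-4))  # about 1e4, below FP16_MAX
```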

# Properties specific to this attention mechanism
- self.supports_attention_mask = True
- self.supports_key_padding_mask = True
+ # The underlying triton op does not support per element attention mask
Contributor Author

see previous PR which introduced these flags, we can now flip them in this case

Contributor

This is in principle a BC-breaking change. Should we bother, or should we just follow what Triton does?

Contributor Author

Basically we don't have much choice, short of implementing blocksparse ourselves? Phil's arguments were that it was not typically used (short of causal, which is supported), and that the attention mask (additive -> floats) ended up taking a significant amount of memory. My take would be that we have an in-house fallback, since people can use sparse attention, and otherwise they can stick to the current pip release for some time. So breaking BC is not something that I would do every day, but in this case it seemed ok?

# If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the
# initial attention setup
# Delayed triton init, to make sure that we get the right device
if not hasattr(self, "sparse_dot_sdd"):
Contributor Author

triton blocksparse needs the correct device to be passed, but that device may not be known at construction time (if constructed on CPU then moved), and it's not possible to just default to cuda:0 (would break many multi-gpu cases). So we defer the construction until the first input tensor comes in

Contributor

sounds good to me.

One other option would be to inherit the .cuda() / .to() methods so that they re-create those objects.

The current approach is fine as is because those objects don't contain learnable parameters, but if that was the case it would mess with optimizers / distributed.

Contributor Author

yep, that's something we discussed with @colehawkins on another PR, one possible issue for me is that some sharded trainers intercept the .to() calls, so this would silently fail in that case. I think that both takes have issues (the delayed init and the .to() overload), the only clean way out that I can think of is to make this attention take a "device" as a construction argument, and put it to the right place from the beginning

Contributor

@blefaudeux @fmassa I've been using this with both pytorch lightning and the huggingface trainer, and this method (initialization at first forward) is the cleanest way I found that doesn't break any standard workflows or require model initialization workarounds. If we just take the device at construction, this breaks the huggingface trainer "natural approach" for single-node, multi-gpu, which (1) creates the model, then (2) calls .to().

One alternative is to initialize with a device at construction, keep that as self.device, and then check that against query.device and possibly re-initialize at the forward pass.
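
The options discussed in this thread (build the op at first forward, and re-create it when the input arrives on another device) can be sketched as follows. This is a minimal, framework-free illustration and not the code merged in this PR: `FakeSparseOp` and `LazyDeviceAttention` are hypothetical names, and devices are plain strings rather than `torch.device` objects.

```python
from typing import Optional


class FakeSparseOp:
    """Stand-in for a device-specific op (hypothetical, not the Triton class)."""

    def __init__(self, device: str) -> None:
        self.device = device


class LazyDeviceAttention:
    """Builds its device-specific op on first use and rebuilds it on a device change."""

    def __init__(self, device: Optional[str] = None) -> None:
        self.device = device                    # optional hint given at construction
        self.op: Optional[FakeSparseOp] = None  # nothing is built yet
        self.builds = 0                         # counts (re)initializations, for illustration

    def forward(self, query_device: str) -> FakeSparseOp:
        # (Re)create the op if it is missing or bound to another device
        if self.op is None or self.op.device != query_device:
            self.op = FakeSparseOp(query_device)
            self.device = query_device
            self.builds += 1
        return self.op


attn = LazyDeviceAttention()
attn.forward("cuda:0")
attn.forward("cuda:0")  # same device: the existing op is reused
attn.forward("cuda:1")  # device changed: the op is rebuilt
print(attn.builds)
```

The check costs one comparison per forward pass and does not depend on `.to()` being called on the module itself, which is the case raised above for trainers that intercept that call.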

@codecov-commenter

codecov-commenter commented Apr 19, 2022

Codecov Report

❗ No coverage uploaded for pull request base (label_attention_properties@8113277). Click here to learn what that means.
The diff coverage is n/a.

@@                      Coverage Diff                      @@
##             label_attention_properties     #272   +/-   ##
=============================================================
  Coverage                              ?   92.69%           
=============================================================
  Files                                 ?       61           
  Lines                                 ?     3393           
  Branches                              ?        0           
=============================================================
  Hits                                  ?     3145           
  Misses                                ?      248           
  Partials                              ?        0           
Flag Coverage Δ
Python 92.69% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.


Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8113277...c4c7b5f. Read the comment docs.

@blefaudeux
Contributor Author

I think that there's some margin in terms of speed around the fused linear layers, the matmul triton op improved a lot in the last few months and fused linear could probably benefit from some of it. Better keeping it for another PR though

@blefaudeux blefaudeux mentioned this pull request Apr 19, 2022
7 tasks
@colehawkins
Contributor

@blefaudeux larger block sizes (at least up to 128) are supported by triton v2.

Also, is any type of recompile option possible for the blocksparse attention? I ran into this issue using triton v2 blocksparse attention in a distributed environment (specifically huggingface trainer). Since the ops are device-specific, they need to be recreated or there is a device mismatch. It's not too hard to work around with a delayed device-specific model initialization, but I think there are potentially smoother workarounds by inheriting the to function or attempting re-creation in the event of a device mismatch. It's also possible that I'm missing a trivial workaround.

Happy to submit both in a PR to either branch, much more timely than last time (1-2 days).

@blefaudeux
Contributor Author

> @blefaudeux larger block sizes (at least up to 128) are supported by triton v2.
>
> Also, is any type of recompile option possible for the blocksparse attention? I ran into this issue using triton v2 blocksparse attention in a distributed environment (specifically huggingface trainer). Since the ops are device-specific, they need to be recreated or there is a device mismatch. It's not too hard to work around with a delayed device-specific model initialization, but I think there are potentially smoother workarounds by inheriting the to function or attempting re-creation in the event of a device mismatch. It's also possible that I'm missing a trivial workaround.
>
> Happy to submit both in a PR to either branch, much more timely than last time (1-2 days).

Oh sure for the block size, I didn't know, and there may be a type check or cast which should be removed also (off the top of my head, we used to force a cast to fp16).

For the device, overloading .to() is possible but there can be a lot of cases to handle (for instance if blocksparse is part of a wrap which intercepts this call), I'm not sure that it will be much cleaner ? The first FW is much slower with Triton anyway due to the JIT, so I think that this has no measurable perf impact. If .to() can read cleaner then sure, and PR welcome to this branch if you want ?

No problem for the timing, not very timely myself on the topic :)

@blefaudeux blefaudeux mentioned this pull request Apr 19, 2022
10 tasks
@blefaudeux blefaudeux added the enhancement New feature or request label Apr 19, 2022
@blefaudeux blefaudeux linked an issue Apr 19, 2022 that may be closed by this pull request
@blefaudeux blefaudeux marked this pull request as ready for review April 19, 2022 21:24
- self.supports_key_padding_mask = True
+ # The underlying triton op does not support per element attention mask
+ self.supports_attention_mask = False
+ self.supports_key_padding_mask = False

def update_mask_type(self, mask: torch.Tensor):
Contributor

Was just reading the code and realised this is not used anymore, safe to delete?

Contributor Author

Nice catch ! In general I need to give it a second look and clean things up, I was waiting for @colehawkins so that there's no conflict but will do

Contributor

@blefaudeux Posted in #277. Pending CI, but tests passed locally so I have high hopes.

# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
return torch.float32, 1e-1
Contributor

wow, that is quite some low precision...

Contributor

@fmassa fmassa left a comment

Thanks for the PR!

I've left a few comments, but given the size of this PR already, I'd propose that we address them in a separate PR after this one is merged.

tests/test_triton_dropout.py Show resolved Hide resolved
x_ref = (x + b if bias else x).to(y.dtype)
assert not torch.allclose(x_ref, y, rtol=tol)

# Check that the drops are different for every row (could catch broken seeds per row)
y = triton_dropout(x, p=0.5)

print(y)
Contributor

leftover from debugging?

@@ -8,7 +8,7 @@
import torch

# Please update the doc version in docs/source/conf.py as well.
- __version__ = "0.0.10"
+ __version__ = "0.0.11.dev"
Contributor

that's a good thing to do indeed!

We should remember to remove this during releases though.

Maybe it would be better to split this off into its own PR?

# TODO triton softmax performs an in-place operation
# res = arg0.__sparse_softmax(arg0.__values)
- res = arg0.__sparse_softmax(arg0.__values.clone())
+ res = arg0.__sparse_softmax(arg0.__values)
Contributor

nice!

@@ -166,15 +169,22 @@ def dropout(
Optionally add a bias, the computation will be fused.
"""

assert p < 1.0, f"We don't want to drop all the values, most probably {p}"
Contributor

PyTorch supports this case, so if we want our dropout to be a drop-in replacement to PyTorch's implementation it would be good to support this as well.

Contributor Author

fixed, incoming update, thanks for the catch !
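
For reference, one way to handle the `p == 1` edge case is to special-case it before computing the `1 / (1 - p)` rescaling factor, which would otherwise divide by zero. The sketch below is a pure-Python illustration of inverted dropout on a list, not the Triton kernel updated in this PR; the `dropout` signature and the `rng` parameter are hypothetical.

```python
import random
from typing import List, Optional


def dropout(x: List[float], p: float, rng: Optional[random.Random] = None) -> List[float]:
    """Inverted dropout on a list of floats; p is the probability of dropping a value."""
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability must be in [0, 1], got {p}")
    if p == 1.0:
        # Everything is dropped: return zeros directly, avoiding the division by (1 - p) == 0
        return [0.0] * len(x)
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)  # rescale the kept values so the expected sum is unchanged
    return [v * scale if rng.random() >= p else 0.0 for v in x]


print(dropout([1.0, 2.0, 3.0], p=1.0))
print(dropout([1.0, 2.0, 3.0], p=0.0))
```

Returning all zeros for `p == 1` matches what PyTorch's own dropout does in training mode, which is the drop-in behaviour asked for in the review comment.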

@blefaudeux blefaudeux changed the title [RFC] 3/3 Triton 2 update [backend] 3/3 Triton 2 update Apr 21, 2022
@blefaudeux
Contributor Author

checking right now with a small ViT/cifar training, then landing the whole stack asap

kashif and others added 5 commits April 20, 2022 20:24
author Kashif Rasul <kashif.rasul@gmail.com> 1648069860 +0100
committer Benjamin Lefaudeux <benjamin.lefaudeux@pm.me> 1650256563 -0700

Move to Triton 2

Author:    Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@pm.me>

Tentatively fixing layernorm

- faster all around
- bugfix

better take on sparse tensors, put layout on the correct device
update the pip packages, minor cleanup
…power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>
@blefaudeux blefaudeux changed the base branch from label_attention_properties to conda_ci April 21, 2022 03:24
@blefaudeux blefaudeux merged commit 4ecbec1 into conda_ci Apr 21, 2022
@blefaudeux
Contributor Author

I tried to address most comments @fmassa, we can follow up on the version numbering and delayed init in another PR. I just checked with my "classical" ViT/Cifar test, just in case, same accuracy as before when pulling in all the triton layers

blefaudeux added a commit that referenced this pull request Apr 21, 2022
…h combo (#271)

* testing using conda to get the pytorch nightlies and matching cuda

* [fix] Making it explicit whether the attention mechanism supports an attention mask or not (#266)

check the assert

* [backend] 3/3 Triton 2 update (#272)

* parent be72b26
author Kashif Rasul <kashif.rasul@gmail.com> 1648069860 +0100
committer Benjamin Lefaudeux <benjamin.lefaudeux@pm.me> 1650256563 -0700

Move to Triton 2

Author:    Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@pm.me>

Tentatively fixing layernorm

- faster all around
- bugfix

better take on sparse tensors, put layout on the correct device
update the pip packages, minor cleanup

* catering for triton blocksparse being probably more reliable in fp16

* faster layernorm

* Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>

* code review, thanks @fmassa !

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: colepshawkins <31542048+colehawkins@users.noreply.github.com>
Co-authored-by: Cole Hawkins <colehawk@amazon.com>

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: colepshawkins <31542048+colehawkins@users.noreply.github.com>
Co-authored-by: Cole Hawkins <colehawk@amazon.com>
@blefaudeux blefaudeux deleted the triton-2 branch April 21, 2022 16:52
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

Is bfloat16 supported?
7 participants