[backend] 3/3 Triton 2 update #272
Conversation
To peek into the remaining issues: https://app.circleci.com/pipelines/github/facebookresearch/xformers/1395/workflows/b90368b4-eda1-4037-8cd3-a4d138b6e320/jobs/3075 - getting there.
@dianaml0 just FYI, I'm investigating the layernorm crash. It actually works fine with CUDA 11.6 / Ampere, but I can repro with 11.4.
Some writes were not masked, which could be why. I just pushed a small update, works on my machine (tm).
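For context, this is the kind of bug being described: in a Triton kernel, any store past the logical row length has to be masked, otherwise it writes out of bounds. Below is a minimal illustrative kernel (not the actual xformers layernorm) where both the load and the store are masked to the valid columns.

import torch
import triton
import triton.language as tl


@triton.jit
def _double_rows_kernel(X, Y, stride, N, BLOCK_N: tl.constexpr):
    # One program per row; BLOCK_N is rounded up to a power of two, so some
    # lanes fall outside the row and must be masked on load AND on store.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + row * stride + cols, mask=mask, other=0.0)
    y = x * 2.0
    # The masked store is the important bit: without `mask=mask`, the trailing
    # lanes would write out of bounds and could corrupt neighbouring memory.
    tl.store(Y + row * stride + cols, y, mask=mask)


x = torch.randn(8, 100, device="cuda")
y = torch.empty_like(x)
_double_rows_kernel[(x.shape[0],)](
    x, y, x.stride(0), x.shape[1], BLOCK_N=triton.next_power_of_2(x.shape[1])
)
assert torch.allclose(y, 2 * x)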
Layernorm fixed, as far as I can see; it should also be a bit faster for very big sizes.
@fmassa if you have some cycles at some point, would you mind having a look at the test_sparse_softmax test, which does not pass with this branch? You have a little more context around there (and you could check the changes from this PR, which are probably a bit rough in this area).
Force-pushed from 6c97cca to 4327d39
I just improved a bit on the changes in /sparse_tensor; it should be OK by now, except that the new blocksparse does not accept a per-pixel mask anymore, so this probably breaks some of this abstraction.
Force-pushed from be72b26 to 8113277
Force-pushed from 6369798 to 9554e3f
return torch.float32, 1e-1

# Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
torch.backends.cuda.matmul.allow_tf32 = False
@fmassa this fixed issues that I was seeing with these unit tests on an ampere GPU, which I presume stemmed from the fact that the sparse kernels were fp32 while pytorch defaulted to tf32
oh wow, thanks for spotting this!
One more instance where tf32 is being somewhat harmful. Maybe worth commenting on pytorch/pytorch#67384 ?
It's a strange format: the range of fp32 but the precision of fp16. It's also kind of peculiar that it's really 19 bits but named TF32..
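To make the trade-off concrete, here is a small PyTorch-only sketch (assuming an Ampere or newer GPU, where fp32 matmuls went through TF32 by default at the time of this PR). The gap between the TF32 and strict-fp32 results is typically around 1e-3 relative, consistent with TF32's 10-bit mantissa.

import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# TF32 path: fp32 inputs, but the matmul internally rounds to a 10-bit mantissa
torch.backends.cuda.matmul.allow_tf32 = True
out_tf32 = a @ b

# Strict fp32 path, as forced in the test above
torch.backends.cuda.matmul.allow_tf32 = False
out_fp32 = a @ b

rel_err = (out_tf32 - out_fp32).norm() / out_fp32.norm()
print(rel_err)  # typically around 1e-3, far above fp32 rounding error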
# Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
# TF32 has the precision of fp16 but the range of fp32
# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
@fmassa this seems to be a better fit following the switch to triton2, which internally moved all tl.dot() operations to tf32
cc @ptillet, just swapping triton 1.1 for 2.dev meant that this test would not pass anymore, as we discussed
SGTM wrt the tests!
def _get_dtype_atol(tensor_type, device: str):
    _seed()
this was to remove some reproducibility issues between CircleCI and my machine..
MODE,
trans_a=TRANS_A,
trans_b=TRANS_B,
device=torch.device("cuda"),
triton blocksparse op now requires the device to be passed in
# triton result
op = blocksparse_softmax(layout, BLOCK)
op = blocksparse_softmax(layout, BLOCK, device=torch.device("cuda"))
triton blocksparse softmax now requires the device to be passed in
ty = op(
    tx,
    scale=scale,
    key_padding_mask=kp_mask,
triton blocksparse no longer supports an attention mask or key padding mask. We can pass a "causal" flag though.
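Since per-element masks are gone, constant masking patterns are better folded into the block layout the op is built with, as noted later in this PR. Below is a hypothetical helper (not part of the xformers API) for the causal case.

import torch


def causal_block_layout(seq_len: int, block_size: int) -> torch.Tensor:
    # 1 for blocks that may contain visible positions, 0 for blocks that are
    # entirely masked out; the blocksparse kernels then skip the zeroed blocks.
    assert seq_len % block_size == 0
    n_blocks = seq_len // block_size
    return torch.tril(torch.ones(n_blocks, n_blocks, dtype=torch.long))


layout = causal_block_layout(seq_len=1024, block_size=64)  # (16, 16) block mask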
@@ -50,7 +50,7 @@ def test_layernorm_parity(shape, amp):
torch.random.manual_seed(0)
X_ = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True)

eps = 1e-5
eps = 1e-4
1/1e-5 overflows in fp16..
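Spelling that out: 1/1e-5 = 1e5, which is above the fp16 maximum of about 65504, while 1/1e-4 = 1e4 still fits. A quick PyTorch check:

import torch

for eps in (1e-5, 1e-4):
    x = torch.tensor(eps, dtype=torch.float16)
    # 1/1e-5 -> ~1e5 > 65504 (fp16 max) -> inf; 1/1e-4 -> 1e4, still finite
    print(eps, torch.reciprocal(x))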
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = True
# The underlying triton op does not support per element attention mask
see previous PR which introduced these flags, we can now flip them in this case
This is in principle a BC-breaking change. Should we bother, or should we just follow what Triton does?
Basically we don't have much choice, short of implementing blocksparse ourselves. Phil's arguments were that it was not typically used (short of causal, which is supported), and that the attention mask (additive -> floats) ended up taking a significant amount of memory. My take is that we have an in-house fallback since people can use sparse attention, and otherwise they can stick to the current pip release for some time. Breaking BC is not something I would do every day, but in this case it seemed OK?
# If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the
# initial attention setup
# Delayed triton init, to make sure that we get the right device
if not hasattr(self, "sparse_dot_sdd"):
Triton blocksparse needs the correct device to be passed in, but that may not be known at construction time (if the module is constructed on CPU and then moved), and it's not possible to just default to cuda:0 (that would break many multi-GPU cases). So we defer the construction until the first input tensor comes in.
Sounds good to me.
One other option would be to override the .cuda() / .to() methods so that they re-create those objects.
The current approach is fine as is because those objects don't contain learnable parameters, but if that were the case it would mess up optimizers / distributed.
Yep, that's something we discussed with @colehawkins on another PR. One possible issue for me is that some sharded trainers intercept the .to() calls, so this would silently fail in that case. I think that both takes have issues (the delayed init and the .to() overload); the only clean way out that I can think of is to make this attention take a "device" as a construction argument and put it in the right place from the beginning.
@blefaudeux @fmassa I've been using this with both PyTorch Lightning and the Hugging Face trainer, and this method (initialization at first forward) is the cleanest way I found that doesn't break any standard workflows or require model initialization workarounds. If we just take the device at construction, this breaks the Hugging Face trainer's "natural approach" for single-node, multi-GPU, which is to (1) create the model, then (2) call .to().
One alternative is to initialize with a device at construction, keep that as self.device, and then check it against query.device and possibly re-initialize at the forward pass.
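For reference, a minimal sketch of the pattern discussed in this thread; the class and attribute names are illustrative, not the actual xformers code. Device-specific ops are built lazily from the first input's device and rebuilt if the input later shows up on a different device.

import torch


class LazyBlockSparseOps(torch.nn.Module):
    """Illustrative module that defers device-specific op construction."""

    def __init__(self):
        super().__init__()
        self._ops_device = None  # device the ops were built for, if any
        self._kernel = None      # stand-in for the triton blocksparse ops

    def _build(self, device: torch.device):
        # In the real code this would construct the blocksparse matmul/softmax
        # ops on `device`; here an Identity stands in for them.
        self._kernel = torch.nn.Identity().to(device)
        self._ops_device = device

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # Build on first use, or rebuild if the input now lives elsewhere
        # (e.g. the model was constructed on CPU and then moved to cuda:1).
        if self._ops_device != q.device:
            self._build(q.device)
        return self._kernel(q)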
Codecov Report
@@ Coverage Diff @@
## label_attention_properties #272 +/- ##
=============================================================
Coverage ? 92.69%
=============================================================
Files ? 61
Lines ? 3393
Branches ? 0
=============================================================
Hits ? 3145
Misses ? 248
Partials ? 0
I think that there's some margin in terms of speed around the fused linear layers; the matmul triton op improved a lot in the last few months and fused linear could probably benefit from some of it. Better to keep that for another PR though.
@blefaudeux larger block sizes (at least up to 128) are supported by triton v2. Also, is any type of recompile option possible for the blocksparse attention? I ran into this issue using triton v2 blocksparse attention in a distributed environment (specifically the huggingface trainer). Since the ops are device-specific, they need to be recreated or there is a device mismatch. It's not too hard to work around with a delayed device-specific model initialization, but I think there are potentially smoother workarounds by inheriting the .to() method. Happy to submit both in a PR to either branch, much more timely than last time (1-2 days).
Oh sure for the block size, I didn't know, and there may be a type check or cast which should be removed also (top of head, we used to force a cast to fp16). For the device, overloading .to() is possible but there can be a lot of cases to handle (for instance if blocksparse is part of a wrapper which intercepts this call), so I'm not sure it would be much cleaner? The first forward is much slower with Triton anyway due to the JIT, so I think that this has no measurable perf impact. If .to() reads cleaner then sure, and a PR is welcome against this branch if you want. No problem for the timing, I'm not very timely myself on the topic :)
self.supports_key_padding_mask = True
# The underlying triton op does not support per element attention mask
self.supports_attention_mask = False
self.supports_key_padding_mask = False

def update_mask_type(self, mask: torch.Tensor):
Was just reading the code and realised this is not used anymore, safe to delete?
Nice catch! In general I need to give it a second look and clean things up; I was waiting for @colehawkins so that there's no conflict, but will do.
@blefaudeux Posted in #277. Pending CI, but tests passed locally so I have high hopes.
# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
return torch.float32, 1e-1
wow, that is quite some low precision...
Thanks for the PR!
I've left a few comments, but given the size of this PR already, I'd propose that we address them in a separate PR after this one is merged.
tests/test_triton_dropout.py (outdated)
x_ref = (x + b if bias else x).to(y.dtype)
assert not torch.allclose(x_ref, y, rtol=tol)

# Check that the drops are different for every row (could catch broken seeds per row)
y = triton_dropout(x, p=0.5)

print(y)
leftover from debugging?
xformers/__init__.py (outdated)
@@ -8,7 +8,7 @@
import torch

# Please update the doc version in docs/source/conf.py as well.
__version__ = "0.0.10"
__version__ = "0.0.11.dev"
that's a good thing to do indeed!
We should remember to remove this during releases though.
Maybe it would be better to split this off into its own PR?
# TODO triton softmax performs an in-place operation
# res = arg0.__sparse_softmax(arg0.__values)
res = arg0.__sparse_softmax(arg0.__values.clone())
res = arg0.__sparse_softmax(arg0.__values)
nice!
xformers/triton/dropout.py (outdated)
@@ -166,15 +169,22 @@ def dropout(
Optionally add a bias, the computation will be fused.
"""

assert p < 1.0, f"We don't want to drop all the values, most probably {p}"
PyTorch supports this case, so if we want our dropout to be a drop-in replacement for PyTorch's implementation it would be good to support this as well.
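For reference, PyTorch's functional dropout does accept p == 1.0; it simply zeroes everything during training and is a no-op in eval mode:

import torch
import torch.nn.functional as F

x = torch.ones(4)
# p = 1.0 is legal in PyTorch: everything is dropped during training...
print(F.dropout(x, p=1.0, training=True))   # tensor([0., 0., 0., 0.])
# ...and dropout leaves the input untouched in eval mode, regardless of p.
print(F.dropout(x, p=1.0, training=False))  # tensor([1., 1., 1., 1.])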
Fixed, incoming update, thanks for the catch!
Checking right now with a small ViT/CIFAR training, then landing the whole stack ASAP.
Move to Triton 2
Author: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@pm.me>
- Tentatively fixing layernorm - faster all around - bugfix
- Better take on sparse tensors, put layout on the correct device
- Update the pip packages, minor cleanup
…power of two constraint (#277)
* Relax device size restrictions
* Refactor device creation and run all tests
* linting
Co-authored-by: Cole Hawkins <colehawk@amazon.com>
I tried to address most comments @fmassa; we can follow up on the version numbering and delayed init in another PR. I just checked with my "classical" ViT/CIFAR test, just in case: same accuracy as before when pulling in all the triton layers.
…h combo (#271)
* testing using conda to get the pytorch nightlies and matching cuda
* [fix] Making it explicit whether the attention mechanism supports an attention mask or not (#266), check the assert
* [backend] 3/3 Triton 2 update (#272): move to Triton 2, tentatively fixing layernorm (faster all around, bugfix), better take on sparse tensors, put layout on the correct device, update the pip packages, minor cleanup
* catering for triton blocksparse being probably more reliable in fp16
* faster layernorm
* Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277): relax device size restrictions, refactor device creation and run all tests, linting
* code review, thanks @fmassa!
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: colepshawkins <31542048+colehawkins@users.noreply.github.com>
Co-authored-by: Cole Hawkins <colehawk@amazon.com>
What does this PR do?
Push things a little forward with Triton 2. My thinking was to try to land all 3 PRs in one go once this last one is green.
Happy to update this PR on that front (and others); this is something quickly wrapped together to try to unlock some other PRs (like #263).
TODOs:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.