Softmax tutorial crashes (invalid arith.select) when n_cols is a multiple of 16 but <= 128 #4739
This seems broken ... the tutorial builds a simple kernel cache using

then run

then this crashes like before, but if I do

then it works OK, as the kernel is correctly compiled for the ... By the way, this also means I can crash the example by running

because the ...
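For illustration, here is a minimal sketch of the kind of block-size-keyed kernel cache being described; the helper names are hypothetical (`compile_softmax` is a stand-in for the tutorial's warmup/compile step, not a Triton API):

```python
import triton

# Sketch of a cache keyed only on the block size. `compile_softmax` is a
# hypothetical stand-in for the tutorial's warmup/compile step.
kernel_cache = {}

def get_softmax_kernel(n_cols, num_stages, num_warps):
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # BUG: the key ignores num_stages and the tensor shape, so a kernel
    # compiled for one configuration is silently reused for another.
    if BLOCK_SIZE not in kernel_cache:
        kernel_cache[BLOCK_SIZE] = compile_softmax(BLOCK_SIZE, num_stages, num_warps)
    return kernel_cache[BLOCK_SIZE]
```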
I don't know what is going on with the original bug, but I did some investigation on the second issue I found, and it looks like there already is a cache implemented for `JITFunction` inside ...
When the OOB values for a `tt.load` are non-zero, the for loop pipeliner needs to generate an `arith.select` to mask the loaded values with the default OOB value. However, if the loaded memory requires a layout change, the wrong mask operand was being passed to the `arith.select`, causing a shape mismatch. The fix is to just use the same mask operand as the original `tt.load` op. Fixes triton-lang#4739
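For context, the pattern that exercises this pipeliner path is a masked load whose out-of-bounds default is non-zero, as in the tutorial where out-of-bounds lanes are filled with `-inf` inside a pipelined loop. A rough sketch in that spirit (simplified, not claimed to be the exact tutorial kernel):

```python
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    # Pipelined loop over rows: each stage issues a masked load whose
    # out-of-bounds lanes must become -inf, which the pipeliner models
    # with an arith.select over the masked-off values.
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        col_offsets = tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        row = tl.load(input_ptr + row_idx * input_row_stride + col_offsets,
                      mask=mask, other=-float('inf'))
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        softmax = numerator / tl.sum(numerator, axis=0)
        tl.store(output_ptr + row_idx * output_row_stride + col_offsets,
                 softmax, mask=mask)
```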
Hey, thanks for identifying and narrowing down this problem! I have a fix for the first problem (the compilation failure) in #5161. I'll dig into trying to fix the tutorial code in a bit -- I also found it confusing when going through the tutorial!
The fused softmax implementation in the tutorial precompiles the kernel to query its register usage, based on the parameters used to specialize it. On top of this, it implements a simple caching system for this step keyed on just the block size. As noted in triton-lang#4739, this caching is incorrect because it is not also keyed on the `num_stages` constexpr argument or the shapes of the tensors. Since Triton already has its own JIT compilation cache, and this caching bit is not really relevant to the tutorial, just remove it to get rid of the footgun.
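To make the direction of that change concrete, here is a hedged sketch of a host-side wrapper that drops the handrolled dict and relies on Triton's JIT cache, which keys on constexpr values such as `BLOCK_SIZE` and `num_stages`. It assumes the `softmax_kernel` sketched above and uses a fixed, illustrative `num_stages`; it is not the actual patch.

```python
import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    # No handrolled cache: Triton's JIT recompiles (and caches) per distinct
    # set of constexpr values, so changing BLOCK_SIZE or num_stages is safe.
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # One program per row; num_stages=4 is an arbitrary illustrative choice.
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                              BLOCK_SIZE=BLOCK_SIZE, num_stages=4)
    return y
```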
`triton.__version__` is 3.0.0 for me.

The tutorial code 02-fused-softmax.py given in https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html fails to compile a kernel during the warmup when `n_cols` is a multiple of 16 that is less than or equal to 128 (i.e., <= 16 * num_warps). Error looks like:

Repro: modify line 195 in 02-fused-softmax.py from

to

and run.
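A hedged repro sketch along those lines (the exact line-195 change is not quoted, so the shape below is an assumed example that satisfies the failing condition; `softmax` refers to the tutorial's host-side wrapper):

```python
import torch

# 128 columns: a multiple of 16 and <= 16 * num_warps, the range the issue
# title identifies as failing. The row count is arbitrary.
x = torch.randn(1823, 128, device='cuda')
y = softmax(x)  # crashes during kernel compilation/warmup per this report
```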