Softmax tutorial crashes (invalid arith.select) when n_cols is a multiple of 16 but <= 128 #4739

Closed

akeley98 opened this issue Sep 17, 2024 · 3 comments · Fixed by #5161

akeley98 commented Sep 17, 2024

triton.__version__ is 3.0.0 for me

The tutorial code 02-fused-softmax.py given at https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html fails to compile a kernel during warmup when n_cols is a multiple of 16 that is less than or equal to 128 (i.e., <= 16 * num_warps). The error looks like:

(triton) mantissa@MantissaAmpere:~/junk$ python3 ../Downloads/02-fused-softmax.py 
loc("/home/mantissa/junk/../Downloads/02-fused-softmax.py":97:22): error: 'arith.select' op expected condition type to have the same shape as the result type, expected 'tensor<128xi1, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>>', but got 'tensor<128xi1, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>>'
Traceback (most recent call last):
  File "/home/mantissa/junk/../Downloads/02-fused-softmax.py", line 196, in <module>
    y_triton = softmax(x)
  File "/home/mantissa/junk/../Downloads/02-fused-softmax.py", line 144, in softmax
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/runtime/jit.py", line 764, in warmup
    return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Repro: modify line 195 in 02-fused-softmax.py from

x = torch.randn(1823, 781, device='cuda')

to

x = torch.randn(1823, 80, device='cuda')

and run.
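
For reference, the repro as a standalone snippet (a minimal sketch that assumes the softmax() wrapper from 02-fused-softmax.py is already defined in scope; only the input width changes from 781 to 80):

import torch

# 80 is a multiple of 16 that is <= 128 (i.e. <= 16 * num_warps), which triggers the crash
x = torch.randn(1823, 80, device='cuda')
y_triton = softmax(x)  # fails while compiling the kernel inside softmax_kernel.warmup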


akeley98 commented Sep 17, 2024

This seems broken. The tutorial builds a simple kernel cache around softmax_kernel.warmup in which the caching depends only on BLOCK_SIZE. However, warmup actually specializes based on the kernel arguments themselves (not just the meta-parameters), and this is where the compiler gets into trouble: it seems to produce an entirely different kernel when y.stride(0) is divisible by 16. So the cached kernel depends on the arguments passed to the first call of softmax(), since that call determines how the kernel is compiled. So if I write a function

def do_it(n_cols):
    torch.manual_seed(0)
    x = torch.randn(10000, n_cols, device='cuda')
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

then run

do_it(n_cols = 80)

then this crashes like before, but if I do

do_it(n_cols = 79)
do_it(n_cols = 80)

then it works OK, as the kernel correctly compiled for the n_cols = 79 case is successfully re-run for the n_cols = 80 case.

By the way, this also means I can crash the example by running

do_it(n_cols = 512)
do_it(n_cols = 511)

because the n_cols = 512 call specializes the cached kernel for aligned data, which doesn't work correctly for n_cols = 511. I'm seeing

Traceback (most recent call last):
  File "/home/mantissa/junk/tl_softmax.py", line 157, in <module>
    do_it(n_cols = 511)
  File "/home/mantissa/junk/tl_softmax.py", line 154, in do_it
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

akeley98 commented

I don't know what is going on with the original bug, but I did some investigation on the second issue I found, and it looks like there already is a cache implemented for JITFunction inside runtime/jit.py that specializes based on whether input tensors are aligned or not. This seems to be implemented with compute_spec_key (checks 16-byte alignment) and the sig_and_spec portion of the key used to index self.cache. Since there's a built-in cache, I'm not sure why the tutorial implements its own (broken) caching system.
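
To illustrate the idea (a schematic sketch only, not the actual code in runtime/jit.py; spec_key_for and compile_for are made-up names for illustration):

cache = {}

def spec_key_for(tensor):
    # Made-up helper: the specialization key is whether the data pointer is
    # 16-byte aligned. (As noted above, integer arguments such as y.stride(0)
    # appear to be specialized on divisibility by 16 as well.)
    return tensor.data_ptr() % 16 == 0

def get_kernel(BLOCK_SIZE, *tensors):
    # The cache key combines the meta-parameter with the per-argument
    # specialization, so a kernel compiled for aligned inputs is never
    # reused for unaligned ones.
    key = (BLOCK_SIZE, tuple(spec_key_for(t) for t in tensors))
    if key not in cache:
        cache[key] = compile_for(BLOCK_SIZE, key)  # made-up compile step
    return cache[key]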

Mogball added a commit to Mogball/triton that referenced this issue Nov 15, 2024
When the OOB values for a `tt.load` are non-zero, the for loop pipeliner
needs to generate an `arith.select` to mask the loaded values with the
default OOB value. However, if the load memory requires a layout change,
the wrong mask operand was being passed to the `arith.select`, causing a
shape mismatch. The fix is to reuse the mask operand of the original `tt.load` op.

Fixes triton-lang#4739

Mogball commented Nov 15, 2024

Hey, thanks for identifying and narrowing down this problem! I have a fix for the first problem identified in #5161 (the compilation failure). I'll dig into trying to fix the tutorial code in a bit -- I also found it confusing when going through the tutorial!

Mogball added a commit to Mogball/triton that referenced this issue Nov 15, 2024
The fused softmax implementation in the tutorial precompiles the kernel
to query the register usage of the kernel, based on the parameters used
to specialize the kernel. On top of this, it implements a simple caching
system for this step based on just the block size.

As noted in triton-lang#4739, this
caching is incorrect, because it's also not keyed on the `num_stages`
constexpr argument or the shapes of the tensors. Since triton already
has its own JIT compilation cache, and this caching bit is not really
relevant to the tutorial, just remove it to get rid of the footgun.
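
For context, the caching pattern being removed looks roughly like this (a paraphrased sketch of the tutorial wrapper, not the exact 02-fused-softmax.py source; softmax_kernel is the tutorial's @triton.jit kernel, and the grid keyword is part of the paraphrase). The handle is cached only by BLOCK_SIZE, so whatever specialization warmup picks on the first call sticks for every later call with the same BLOCK_SIZE:

import torch
import triton

kernels = {}

def softmax(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    kernel = kernels.get(BLOCK_SIZE)
    if kernel is None:
        # warmup() specializes on the concrete arguments (strides, pointer
        # alignment, ...), not just on BLOCK_SIZE, so this key is too coarse.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                                       BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
        kernels[BLOCK_SIZE] = kernel
    # ... occupancy query and actual kernel launch elided ...
    return y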
Mogball added a commit that referenced this issue Nov 15, 2024
hmalgewatta pushed a commit to hmalgewatta/triton that referenced this issue Nov 15, 2024