diff --git a/BENCHMARKS.md b/BENCHMARKS.md index 8724ca3526..2452ba633d 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -39,8 +39,7 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py ### Fused softmax -You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10. - +You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12. ![Softmax throughput in fp16 - inference](docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png) @@ -52,8 +51,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/ ### Fused linear layer -You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10. - +You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12. ![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png) @@ -77,7 +75,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/ ### Fused layer norm -You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10. +You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12. ![Fused layer norm throughput in fp16 - inference](docs/plots/layer_norm/LayerNorm_FW_torch.float16.png) @@ -89,7 +87,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/ ### Fused dropout + bias + activation -You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10. +You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12. 
![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd0f2e13c0..af4f1fbe41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Mem efficient attention, FW pass [#267] - MHA benchmark - MLP benchmark +- Move all triton kernels to triton v2 [#272] ## [0.0.10] - 2022-03-14 ### Fixed diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png index b6daad60b1..4b01969654 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png index a6ce46bd11..4d969d9602 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png index 53551bfb50..dc93c1a99b 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png index c38a9920ee..41e95e9656 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png index 3fab1aa5c9..33e6949735 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png index 87aaccf6cb..a992f5b973 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png index 77a3ca064a..6219071691 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png index c9aa6fe7f6..cb092efbac 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png and 
b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png index f0ebe52763..12b9e19fb4 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png index 55b709cc3c..681a8bf7d3 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png index 6ff747c84e..fcc4480db5 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png index 18ddbe839d..8889349a46 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png index c180cd479c..bb1f3168ca 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png index 6cd37c7822..ba8109f341 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png index baa5b1a640..f37a780e6b 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png index 3764396b94..4ad91ec132 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png index 800e2f4a28..a78c14e75f 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png and 
b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png index 5017cf7266..cd635ce765 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png index f9b9ad27ef..f5b0633a12 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png index d4e4e06d38..3678c34b18 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png index 169e0304aa..8dd7ea07a9 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png index c161c0e121..f5a7465510 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png index ed60dd3363..4b60ff3e91 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png index 65522a30e1..d7e3f28351 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png index cd3b34fdc1..c512217ed7 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png index 2d044d81c2..b1f1cea029 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png index 
a215a4b59f..16eade9d04 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png index df7d4657a3..4cf8daeb30 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png index 39954721d7..4529a4c306 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png index b962f74757..a46a22875a 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png index 272b27428b..709be166af 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png index 5ca77992d8..153b8a3371 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_none.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png index d2b3c12eef..4dae159033 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png differ diff --git a/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png b/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png index 64d0b564c2..f110c42ed9 100644 Binary files a/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png and b/docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png index bdd15c9e7f..2bb20cd112 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png index 640e143da3..17600ed1e5 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png index a790cc4a29..4ccf849881 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png differ diff --git a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png index f076a5a4a2..61cb30654b 100644 Binary files a/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png and b/docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png differ diff --git 
a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png index 154a3e530f..297ec3c667 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png and b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png index 38df22e39a..256b8d9f37 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png and b/docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png b/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png index be9c73df59..4000421a80 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png and b/docs/plots/layer_norm/LayerNorm_FW_torch.float16.png differ diff --git a/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png b/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png index 2704aa4e2e..6012e52e74 100644 Binary files a/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png and b/docs/plots/layer_norm/LayerNorm_FW_torch.float32.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 867d61cf09..3da4d8f386 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -37,7 +37,6 @@ # The full version, including alpha/beta/rc tags release = "0.0.10" - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be diff --git a/examples/microGPT.py b/examples/microGPT.py index 2a9d9b38a2..048ae962a6 100644 --- a/examples/microGPT.py +++ b/examples/microGPT.py @@ -21,7 +21,7 @@ class GPT(pl.LightningModule): - """ the full GPT language model, with a context size of block_size """ + """the full GPT language model, with a context size of block_size""" def __init__( self, diff --git a/experimental/ragged_inference/garbage_pad_ragged_acts.py b/experimental/ragged_inference/garbage_pad_ragged_acts.py index 2238521107..7199751559 100644 --- a/experimental/ragged_inference/garbage_pad_ragged_acts.py +++ b/experimental/ragged_inference/garbage_pad_ragged_acts.py @@ -18,9 +18,9 @@ def garbage_pad_ragged_acts_kernel( ragged_acts_offset_per_seq_ptr, n_ctx_per_seq_ptr, padded_acts_ptr, - **meta, # Optional meta-parameters for the kernel + BLOCK_SIZE: tl.constexpr, # How many inputs each program should process + n_ctx_max: tl.constexpr, ): - BLOCK_SIZE = meta["d_model"] # How many inputs each program should process # There are multiple 'program's processing different data. 
We identify which program # we are here @@ -47,7 +47,6 @@ def garbage_pad_ragged_acts_kernel( acts = tl.load(ragged_acts_ptr + ragged_acts_offsets, mask=ctx_idx_too_large_mask) # Calculate the offsets for the padded acts - n_ctx_max = meta["n_ctx_max"] padded_acts_offset = n_ctx_max * seq_idx * BLOCK_SIZE # Write things back, again masking out the sections that would be garbage @@ -153,7 +152,7 @@ def triton_to_garbage_padded(self) -> torch.Tensor: torch.tensor(ragged_acts_offset_per_seq, device="cuda"), torch.tensor(self.n_ctx_per_seq, device="cuda"), padded_acts, - d_model=d_model, + BLOCK_SIZE=d_model, n_ctx_max=n_ctx_max, ) return padded_acts diff --git a/requirements-benchmark.txt b/requirements-benchmark.txt index 4522a1ac3d..c82da257d5 100644 --- a/requirements-benchmark.txt +++ b/requirements-benchmark.txt @@ -7,5 +7,4 @@ scikit-learn == 0.24.1 tqdm == 4.59.0 pandas == 1.2.4 seaborn == 0.11.1 -triton == 1.1.2.dev20220106 pytorch-lightning >= 1.3 diff --git a/requirements-test.txt b/requirements-test.txt index 850091a153..d8157d6e30 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -24,3 +24,6 @@ hydra-core >= 1.1 # Dependency for Mixture of Experts fairscale >= 0.4.5 + +# Dependency for fused layers, optional +triton == 2.0.0.dev20220403 diff --git a/tests/test_sparse_tensors.py b/tests/test_sparse_tensors.py index 9b3d819ba9..e1a60917bd 100644 --- a/tests/test_sparse_tensors.py +++ b/tests/test_sparse_tensors.py @@ -19,7 +19,7 @@ def _create_blocksparse_tensor( device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32 ): - layout = torch.randint(2, (C, H // block_size, W // block_size)) + layout = torch.randint(2, (C, H // block_size, W // block_size), device=device) layout[:, :, 0] = 1 layout[:, 0, :] = 1 values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to( @@ -56,6 +56,29 @@ def _create_tensor(tensor_type, device, dtype, shape, sparsity): ) +def _seed(): + torch.random.manual_seed(42) + torch.cuda.manual_seed_all(42) + + +def _get_dtype_atol(tensor_type, device: str): + _seed() + + if tensor_type == BlockSparseTensor and "cuda" in device: + # Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations + # TF32 has the precision of fp16 but the range of fp32 + # See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + return torch.float32, 1e-1 + + # Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + return torch.float32, 1e-5 + + @pytest.mark.parametrize("device", _devices) @pytest.mark.parametrize("func", [torch.add, torch.mul]) def test_sparse_binary_ops(func, device): @@ -83,6 +106,7 @@ def test_sparse_binary_ops(func, device): def test_masked_matmul(tensor_type, device): N, C, H, W, L = 8, 2, 64, 64, 32 sparsity = 0.7 + dtype, atol = _get_dtype_atol(tensor_type, device) shape0 = (N, C, H, W) shape1 = (N, C, H, L) @@ -98,8 +122,8 @@ def test_masked_matmul(tensor_type, device): ) mask = mask_sparse.to_dense() - a = torch.randn(shape1, device=device) - b = torch.randn(shape2, device=device) + a = torch.randn(shape1, device=device, dtype=dtype) + b = torch.randn(shape2, device=device, dtype=dtype) aa = a.clone() bb = b.clone() @@ -119,24 +143,23 @@ def test_masked_matmul(tensor_type, device): res_dense = torch.where(mask, res_dense, 
torch.full_like(res_dense, float("-inf"))) assert res.dtype == res_gt.dtype - assert torch.allclose(res_dense, res_gt, atol=5e-6) + assert torch.allclose(res_dense, res_gt, atol=atol) # try to workaround non-contiguous issues with triton for now res_gt.backward(torch.ones_like(res_gt)) res.values().backward(torch.ones_like(res.values())) - # TODO: this is not passing for BlockSparse!!! - if tensor_type != BlockSparseTensor: - assert torch.allclose(a.grad, aa.grad, atol=5e-6) - assert torch.allclose(b.grad, bb.grad, atol=5e-6) + + assert torch.allclose(a.grad, aa.grad, atol=atol) + assert torch.allclose(b.grad, bb.grad, atol=atol) @pytest.mark.parametrize("tensor_type", _tensor_types) @pytest.mark.parametrize("device", _devices) def test_bmm(tensor_type, device): N, C, H, W, L = 8, 2, 64, 64, 32 - dtype = torch.float32 - sparsity = 0.8 + dtype, atol = _get_dtype_atol(tensor_type, device) + sparsity = 0.8 shape0 = (N, C, H, W) shape1 = (N, C, W, L) @@ -153,7 +176,7 @@ def test_bmm(tensor_type, device): a_sparse.requires_grad_(True) a.requires_grad_(True) - b = torch.randn(shape1, device=device) + b = torch.randn(shape1, device=device, dtype=dtype) b2 = b.clone() b.requires_grad_(True) @@ -163,7 +186,9 @@ def test_bmm(tensor_type, device): res = a_sparse @ b2 assert res.dtype == res_gt.dtype - assert torch.allclose(res, res_gt, atol=1e-5) + assert torch.allclose( + res, res_gt, atol=atol + ), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}" res_gt.sum().backward() res.sum().backward() @@ -171,15 +196,18 @@ def test_bmm(tensor_type, device): a_grad = a.grad.clone().detach() a_grad[~mask] = 0 - assert torch.allclose(b.grad, b2.grad, atol=1e-5) - assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-5) + assert torch.allclose(b.grad, b2.grad, atol=atol) + assert torch.allclose( + a_grad, a_sparse.grad.to_dense(), atol=atol + ), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}" @pytest.mark.parametrize("tensor_type", _tensor_types) @pytest.mark.parametrize("device", _devices) def test_sparse_softmax(tensor_type, device): N, C, H, W = 8, 2, 64, 64 - dtype = torch.float32 + dtype, atol = _get_dtype_atol(tensor_type, device) + sparsity = 0.8 shape0 = (N, C, H, W) @@ -203,7 +231,9 @@ def test_sparse_softmax(tensor_type, device): res = res_sparse.to_dense() assert res.dtype == res_gt.dtype - assert torch.allclose(res, res_gt) + assert torch.allclose( + res, res_gt, atol=atol + ), f"{torch.max(torch.abs(res- res_gt))}" # WARNING: gradients are modified in-place! res_sparse.values().backward(torch.ones_like(res_sparse.values())) @@ -212,7 +242,9 @@ def test_sparse_softmax(tensor_type, device): a_grad = a.grad.clone() a_grad[~mask] = 0 - assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-6) + assert torch.allclose( + a_grad, a_sparse.grad.to_dense(), atol=atol + ), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}" @pytest.mark.parametrize("tensor_type", _tensor_types) diff --git a/tests/test_triton_basics.py b/tests/test_triton_basics.py index 69bb472a2f..0c75232482 100644 --- a/tests/test_triton_basics.py +++ b/tests/test_triton_basics.py @@ -36,7 +36,7 @@ if _triton_available: @triton.jit - def k_mean(X, Mean, Var, stride, N, **META): + def k_mean(X, Mean, Var, stride, N, BLOCK_SIZE_N: tl.constexpr): # fmt: on """ Fused layernorm kernel over a 3d tensor. 
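The hunk above captures the core of the Triton 1.x to 2.0 migration done throughout this diff: kernels no longer receive an implicit `**meta` dict, and every compile-time value becomes an explicit `tl.constexpr` argument passed by keyword at the call site. A minimal sketch of the pattern, using a hypothetical `k_scale` kernel that is not part of this diff:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def k_scale(X, Y, N, scale, BLOCK_SIZE: tl.constexpr):
    # Triton 1.x would have taken **meta and read meta["BLOCK_SIZE"];
    # in 2.0 the block size is a named compile-time constant
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask, other=0.0)
    tl.store(Y + offsets, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    # x is expected to live on a CUDA device
    y = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)  # noqa: E731
    k_scale[grid](x, y, n, s, BLOCK_SIZE=1024)  # constexpr arguments are passed by keyword
    return y
```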
@@ -47,7 +47,7 @@ def k_mean(X, Mean, Var, stride, N, **META): """ row = tl.program_id(0) - cols = tl.arange(0, META["BLOCK_SIZE_N"]) + cols = tl.arange(0, BLOCK_SIZE_N) # Move to this row x_ptrs = X + row * stride + cols diff --git a/tests/test_triton_blocksparse.py b/tests/test_triton_blocksparse.py index 2c09f3b93a..a0ffe78595 100644 --- a/tests/test_triton_blocksparse.py +++ b/tests/test_triton_blocksparse.py @@ -72,7 +72,14 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) # triton result - op = blocksparse_matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) + op = blocksparse_matmul( + layout, + BLOCK, + MODE, + trans_a=TRANS_A, + trans_b=TRANS_B, + device=torch.device("cuda"), + ) ra = block_sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a rb = block_sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest) @@ -91,7 +98,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu") -@pytest.mark.parametrize("BLOCK", [32]) +@pytest.mark.parametrize("BLOCK", [32, 128]) @pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792]) @pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32]) def test_softmax(BLOCK, WIDTH, DTYPE): @@ -103,34 +110,15 @@ def test_softmax(BLOCK, WIDTH, DTYPE): # create inputs layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") - at_mask = torch.randint( - low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda" - ) - kp_mask = torch.randint( - low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda" - ) - kp_mask[kp_mask == 1.0] = float("-inf") # triton result - op = blocksparse_softmax(layout, BLOCK) + op = blocksparse_softmax(layout, BLOCK, device=torch.device("cuda")) tx = block_sparsify_tensor(x, layout, BLOCK) - ty = op( - tx, - scale=scale, - key_padding_mask=kp_mask, - key_padding_mask_mode="add", - attn_mask=at_mask.to(DTYPE), - attn_mask_mode="mul", - ) + ty = op(tx, scale=scale) # torch result rx = triton.testing.mask_tensor(x, layout, BLOCK, value=float("-inf")) - if at_mask is not None: - # broadcast at_mask to the same shape as rx - M = at_mask[None, None, :, :] + torch.zeros_like(rx) - rx[M == 0] = float("-inf") - if kp_mask is not None: - rx += kp_mask[:, None, None, :] + ry = torch.softmax(rx * scale, -1) ry = block_sparsify_tensor(ry, layout, BLOCK) @@ -139,12 +127,12 @@ def test_softmax(BLOCK, WIDTH, DTYPE): @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu") -@pytest.mark.parametrize("block", [32, 43]) # 16, 32, +@pytest.mark.parametrize("block", [32, 43, 128]) # 16, 32, def test_attention_fwd_bwd( block, input_scale=1.0, scale=1 / 8.0, - n_ctx=256, + n_ctx=384, dtype=torch.float16, batch_size=2, n_heads=2, @@ -158,34 +146,26 @@ def test_attention_fwd_bwd( .cuda() for _ in range(3) ] - attn_mask = torch.tril( - torch.ones( - [n_ctx, n_ctx], - device="cuda", - dtype=dtype, - ), - diagonal=0, - ) def loss_fn(x): return (x ** 2).mean() # Triton: n_blocks = n_ctx // block - layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + layout = torch.tril( + torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1 + ) query, key, value = [x.clone() for x in qkvs] 
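For reference when reading the updated blocksparse tests: the layout is a per-head grid of 0/1 flags, one entry per `block_size x block_size` tile, so the sequence length has to be a multiple of the block size (the power-of-two restriction is dropped later in this diff). A small illustration with hypothetical sizes mirroring the test above:

```python
import torch

n_heads, n_ctx, block = 2, 384, 128   # n_ctx is a multiple of block: 3 tiles per sequence axis
n_blocks = n_ctx // block

# layout as built in the updated test: keep strictly lower-triangular tiles only
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1)
assert layout.shape == (n_heads, n_blocks, n_blocks)
```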
query.retain_grad() key.retain_grad() value.retain_grad() - if block not in [16, 32, 64]: + if block not in [16, 32, 64, 128]: # Check that unsupported dimensions are caught with pytest.raises(AssertionError): _ = BlockSparseAttention(layout, block) else: block_sparse_attention = BlockSparseAttention(layout, block) - attn_out = block_sparse_attention( - att_mask=attn_mask, q=query, k=key, v=value, scale=scale - ) + attn_out = block_sparse_attention(q=query, k=key, v=value, scale=scale) # ad hoc loss loss = loss_fn(attn_out) @@ -195,12 +175,10 @@ def loss_fn(x): # Torch version: torch_q, torch_k, torch_v = [x.clone() for x in qkvs] torch_q = torch_q / math.sqrt(head_dim) - attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) torch_q.retain_grad() torch_k.retain_grad() torch_v.retain_grad() scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) - scores = scores + attn_mask probs = torch.softmax(scores, dim=-1) torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index c34a757572..13265b6d99 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -88,9 +88,14 @@ def test_dropout(shape, amp, bias, p): x_ref = (x + b if bias else x).to(y.dtype) assert torch.allclose(x_ref, y, rtol=tol), f"{x[x>y]}" - # Check that 1 means dropout for sure + # Check that 1 means drop all y = triton_dropout(x, p=1, bias=b) x_ref = (x + b if bias else x).to(y.dtype) + assert torch.allclose(torch.zeros_like(y), y, rtol=tol) + + # Check that .99 means probably dropout + y = triton_dropout(x, p=0.99, bias=b) + x_ref = (x + b if bias else x).to(y.dtype) assert not torch.allclose(x_ref, y, rtol=tol) # Check that the drops are different for every row (could catch broken seeds per row) diff --git a/tests/test_triton_fused_linear.py b/tests/test_triton_fused_linear.py index dadb254b5a..a7789824d5 100644 --- a/tests/test_triton_fused_linear.py +++ b/tests/test_triton_fused_linear.py @@ -38,7 +38,7 @@ "dtype", [torch.float32] ) # Triton use tensor cores, which return slightly different results to pytorch mm def test_fused_matmul(shape, dtype): - """ Check that the matrix multiply kernel and Pytorch's give the same results""" + """Check that the matrix multiply kernel and Pytorch's give the same results""" torch.random.manual_seed(0) # Raw fused matrix multiply first, to catch gross errors diff --git a/tests/test_triton_layernorm.py b/tests/test_triton_layernorm.py index 0af7bbe82a..6de484179d 100644 --- a/tests/test_triton_layernorm.py +++ b/tests/test_triton_layernorm.py @@ -50,7 +50,7 @@ def test_layernorm_parity(shape, amp): torch.random.manual_seed(0) X_ = torch.normal(0, 1, size=shape, device="cuda", requires_grad=True) - eps = 1e-5 + eps = 1e-4 # Initialize the two layers, weights are 1 and 0 by default, no randomness torch_layernorm = torch.nn.LayerNorm(X.shape[-1], eps=eps).to("cuda") diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index 5d8fd263ec..4598bc20ab 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -26,7 +26,7 @@ def pretty_print(results, title, units): - """ Printout the contents of a dict as a human-readable and Markdown compatible array""" + """Printout the contents of a dict as a human-readable and Markdown compatible array""" print(title) header = " Units: {:<45}".format(units) print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) diff --git 
a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py index e0f28e9446..12094fdd57 100644 --- a/xformers/components/attention/blocksparse.py +++ b/xformers/components/attention/blocksparse.py @@ -7,7 +7,6 @@ import logging import math from dataclasses import dataclass -from typing import Optional import torch @@ -21,7 +20,6 @@ from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore - from xformers.triton.softmax import MaskType from xformers.triton.utils import gpu_capabilities_older_than_70 # Blocksparse requires Tensor cores @@ -66,6 +64,7 @@ def __init__( block_size: int = 16, dropout: float = 0.0, num_heads: int = 1, # optional, used to adapt the layout if in need + causal: bool = False, *args, **kwargs, ): @@ -83,15 +82,38 @@ def __init__( 16, 32, 64, - ), "Only block sizes in [16, 32, 64] are supported" + 128, + ), "Only block sizes in [16, 32, 64, 128] are supported" super().__init__() + + self.causal = causal + self.attn_drop = torch.nn.Dropout(dropout, inplace=False) # Pure blocksparse data self.layout = layout self.block_size = block_size + # make sure that the head dimension is not folded down with the batch + self.requires_head_dimension = True + + # key padding mask and attention mask must be passed in separately + self.requires_same_k_q_dimensions = True + + # The underlying triton op does not support per element attention mask + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + def update_mask_type(self, mask: torch.Tensor): + global _mask_type_warning + if _mask_type_warning: + logging.warning( + "Mask has to be additive. Fixing that but this slows things down" + ) + mask = bool_mask_to_additive(mask) + + def create_triton_kernels(self, device): # blocksparse operators self.sparse_dot_sdd = blocksparse_matmul( self.layout, @@ -99,67 +121,47 @@ def __init__( "sdd", trans_a=False, trans_b=True, + device=device, ) + self.sparse_dot_dsd = blocksparse_matmul( self.layout, self.block_size, "dsd", trans_a=False, trans_b=False, + device=device, ) - self.sparse_softmax = blocksparse_softmax(self.layout, self.block_size) - - # make sure that the head dimension is not folded down with the batch - self.requires_head_dimension = True - - # key padding mask and attention mask must be passed in separately - self.requires_separate_masks = True - self.requires_same_k_q_dimensions = True - - # Properties specific to this attention mechanism - self.supports_attention_mask = True - self.supports_key_padding_mask = True - def update_mask_type(self, mask: torch.Tensor): - global _mask_type_warning - if _mask_type_warning: - logging.warning( - "Mask has to be additive. Fixing that but this slows things down" - ) - mask = bool_mask_to_additive(mask) + self.sparse_softmax = blocksparse_softmax( + self.layout, + self.block_size, + device=device, + ) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - att_mask: Optional[torch.Tensor] = None, - key_padding_mask: Optional[torch.Tensor] = None, scale: float = 1.0, *args, **kwargs, ) -> torch.Tensor: - r""" - att_mask A 2D attention mask. The dtype must be the same as q. An additive mask is expected, - meaning float values using "-inf" to mask values. - key_padding_mask A mask with size (batch size x sequence length). The dtype must be the same as q. 
- An additive mask is expected, meaning float values using "-inf" to mask values - """ + assert ( + "att_mask" not in kwargs.keys() and "att_mask" not in args + ), "This attention does not support an attention mask, but you can specify causality." - # NOTE: - # The attention mask will be taken into account when computing the softmax - # meaning that non-masked values which are present in the initial blocksparse layout will be computed. - # If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the - # initial attention setup + r""" + A thin wrap around the Triton blockparse attention operation - if att_mask is not None and att_mask.dtype == torch.bool: - self.update_mask_type(att_mask) - if key_padding_mask is not None and key_padding_mask.dtype == torch.bool: - self.update_mask_type(key_padding_mask) + .. note: Per element attention mask is not supported, but you can specify causality + """ - assert ( - att_mask is None or att_mask.dim() == 2 - ), "The attention mask is constant across heads, expected dimensions are [seq x seq]" + # Delayed triton init, to make sure that we get the right device + # Infer device from query + if not hasattr(self, "sparse_dot_sdd"): + self.create_triton_kernels(q.device) assert ( q.shape[-2] == k.shape[-2] @@ -172,22 +174,16 @@ def forward( k.shape[-2] == self.layout.shape[-2] * self.block_size ), "Actual sequence size and layout are inconsistent" - assert math.log( - q.shape[-2], 2 - ).is_integer(), ( - "For now blocksparse only works on power-of-two sequence lengths" + assert ( + q.shape[-2] % self.block_size + ) == 0, "Sequence length {} must be a multiple of block size {}".format( + q.shape[-2], self.block_size ) # Blocksparse only works on fp16 q_dtype = q.dtype q, k, v = q.half(), k.half(), v.half() - if att_mask is not None: - att_mask = att_mask.half() - - if key_padding_mask is not None: - key_padding_mask = key_padding_mask.half() - # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S) # When the computations are block sparse, the matrix types change along the way: # - (sparse) attention matrix = (dense) Kt * (dense) Q @@ -196,12 +192,7 @@ def forward( # - softmax on the sparse attention matrix sparse_att_mat = self.sparse_softmax( - sparse_att_mat, - scale=scale, - key_padding_mask=key_padding_mask, - attn_mask=att_mask, - key_padding_mask_mode=MaskType.ADD, - attn_mask_mode=MaskType.ADD, + sparse_att_mat, scale=scale, is_causal=self.causal ) sparse_att_mat = self.attn_drop(sparse_att_mat) diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py index 6b23d38d0b..af6f20b599 100644 --- a/xformers/components/attention/linformer.py +++ b/xformers/components/attention/linformer.py @@ -47,7 +47,7 @@ def __init__( # kq need to have the same dimension self.requires_same_k_q_dimensions = True - # Properties specific to this attention mechanism + # This attention does not support attention masks self.supports_attention_mask = False def forward( diff --git a/xformers/sparse/blocksparse_tensor.py b/xformers/sparse/blocksparse_tensor.py index 840dd7a4fe..7117520a19 100644 --- a/xformers/sparse/blocksparse_tensor.py +++ b/xformers/sparse/blocksparse_tensor.py @@ -107,6 +107,9 @@ def __new__(cls, values, layout): def __init__(self, values, layout): assert values.shape[-2] == values.shape[-1] + assert ( + values.device == layout.device + ), "Both values and layout need to reside on the same device" block_size = values.shape[-1] # TODO: make this check conditioned on 
the use of Triton assert block_size >= 16, "Minimum block size is 16, for now at least" @@ -125,13 +128,13 @@ def __init__(self, values, layout): def _initialize_triton_ops(self): block_size = self.__values.shape[-1] - self.__sparse_dot_sdd = blocksparse_matmul( self.__layout, block_size, "sdd", trans_a=False, trans_b=True, + device=self.__layout.device, ) self.__sparse_dot_dsd = blocksparse_matmul( self.__layout, @@ -139,8 +142,11 @@ def _initialize_triton_ops(self): "dsd", trans_a=False, trans_b=False, + device=self.__layout.device, + ) + self.__sparse_softmax = blocksparse_softmax( + self.__layout, block_size, device=self.__layout.device ) - self.__sparse_softmax = blocksparse_softmax(self.__layout, block_size) def __repr__(self): return f"block_sparse_tensor(shape={self.shape}, values={self.__values})" @@ -195,9 +201,7 @@ def _softmax(cls, arg0, dim): if not (dim == -1 or dim == 2): return NotImplemented if _can_use_triton(arg0): - # TODO triton softmax performs an in-place operation - # res = arg0.__sparse_softmax(arg0.__values) - res = arg0.__sparse_softmax(arg0.__values.clone()) + res = arg0.__sparse_softmax(arg0.__values) else: res = _softmax(arg0.__layout, arg0.__values) return cls._wrap(res, arg0) diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index cdb958d75c..f14099d5ad 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -36,6 +36,7 @@ def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias): M, N = x_.shape assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N) + assert p > 0.0 def grid(meta): # NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for @@ -53,9 +54,11 @@ def grid(meta): seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32) # fmt: off + bias_ptr = bias if bias is not None else x_ # Possibly not being used + k_dropout_fw[grid]( y, x_, - bias if bias is not None else x_, + bias_ptr, seeds, y.stride(0), M, N, @@ -166,15 +169,25 @@ def dropout( Optionally add a bias, the computation will be fused. 
""" + assert p <= 1.0 and p >= 0.0 + + if p == 1.0: + return torch.zeros_like(x) + # Micro optim, skip dropout - if p == 0.0 and activation is None: - return x + bias if bias is not None else x + if p == 0.0: + x = x + bias if bias is not None else x + if activation is not None: + activation_fn = build_activation(activation) + return activation_fn(x) + return x + # The normal triton enabled codepath act_kernel = get_triton_activation_kernel(activation) act_grad_kernel = get_triton_activation_bwd_kernel(activation) return _dropout.apply( x, - p, + float(p), bias, act_kernel, act_grad_kernel, @@ -190,7 +203,13 @@ def __init__( activation: Optional[Activation] = None, ) -> None: super().__init__() - self.p = p + + self.p = float(p) + + assert ( + self.p < 1.0 + ), f"We don't want to drop all the values, most probably p={p} is not properly set" + self.activation_type = activation self.bias = ( torch.zeros(bias_shape, requires_grad=True) @@ -213,10 +232,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: perf_check = x.shape[-1] > 512 # Catch a non-cuda setup, fallback to pytorch - if not x.is_cuda or not perf_check: + if not x.is_cuda or not perf_check or p == 0.0: x = x + self.bias if self.bias is not None else x x = self.pytorch_activation(x) - return torch.nn.functional.dropout(x, p) + return torch.nn.functional.dropout(x, p) if p > 0.0 else x # The normal, Triton-backed path return _dropout.apply( diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index d39d66022b..24f4993fb6 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -20,18 +20,18 @@ @triton.jit -def _get_4_bin_masks(seed, rand_offsets, p): - seed = tl.load(seed) +def _get_4_bin_masks(seed_ptr, rand_offsets, p): + seed = tl.load(seed_ptr) rand1, rand2, rand3, rand4 = tl.randint4x(seed, rand_offsets) # binarize masks, save registers - # NOTE: We keep the random numbers as is there (integers over int32), + # NOTE: We keep the random numbers as is there (integers over uint32), # and convert the threshold instead, for speed - # The initial distribution is -2**31 / 2**31 -1 + # The initial distribution is over 2**32 -1 # and our float threshold is in between [0, 1] # The full computation is: `start_point + full range * p` - threshold = (-2147483648.0 + 4294967295.0 * p).to(tl.int32) + threshold = (4294967296.0 * p).to(tl.int32) rand_mask1 = rand1 > threshold rand_mask2 = rand2 > threshold rand_mask3 = rand3 > threshold @@ -44,18 +44,44 @@ def _get_4_bin_masks(seed, rand_offsets, p): def _random_prune_and_scale(x, rand_mask, p, p_scale): zero = 0.0 - if p > 0.0: - # generate all the random numbers for the block at once, then reshape - keep = tl.reshape(rand_mask, x.shape) - - # prune and normalize in one go - x = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype)) + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, x.shape) + # prune and normalize in one go + x = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype)) return x +@triton.jit +def tile_random_drop( + x_ptrs, + y_ptrs, + block_mask, + use_bias, + bias, + rand_mask, + p, + p_scale, + ACTIVATION, +): + x = tl.load(x_ptrs, mask=block_mask, other=0.0) + + # optionally apply a fused bias + if use_bias: + x += bias + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION: + x = ACTIVATION(x) + + # randomly prune (and scale) the resulting buffer, possibly a no-op + output = _random_prune_and_scale(x, rand_mask, p, p_scale) + + 
tl.store(y_ptrs, output, mask=block_mask) # output + + # fmt: off -@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) +@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]}) @triton.autotune( configs=_configs, key=["M", "N", "is_fp16"], @@ -67,7 +93,12 @@ def k_dropout_fw( M, N, p, is_fp16, # autotune - **meta, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SIZE_RAND_BLOCK: tl.constexpr, + USE_BIAS: tl.constexpr, + ACTIVATION: tl.constexpr, ): """ Apply dropout on an input tensor @@ -79,10 +110,6 @@ def k_dropout_fw( """ # fmt: on - BLOCK_M = meta["BLOCK_M"] - BLOCK_N = meta["BLOCK_N"] - SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] - row_id = tl.program_id(axis=0) rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) @@ -99,14 +126,16 @@ def k_dropout_fw( rand_mask1, rand_mask2, rand_mask3, rand_mask4 = _get_4_bin_masks(seed, rand_offsets, p) col_mask = cols[None, :] < N - p_scale = 1 / (1 - p) if p < 1. else 1. + p_scale = 1 / (1 - p) - if meta["USE_BIAS"]: + if USE_BIAS: b_ptrs = BIAS + cols[None, :] bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.) + else: + bias = x_ptrs # will not be used + # cycle through the binary masks (workaround / no indexing) for i in range(4): - # cycle through the binary masks (workaround / no indexing) if i == 0: rand_mask = rand_mask1 elif i == 1: @@ -117,29 +146,15 @@ def k_dropout_fw( rand_mask = rand_mask4 block_mask = (rows[:, None] < M) & col_mask - x = tl.load(x_ptrs, mask=block_mask, other=0.) - - # optionally apply a fused bias - if meta["USE_BIAS"]: - x += bias - - # optional: fused activation (while the data is in shared memory) - if meta["ACTIVATION"]: - x = meta["ACTIVATION"](x) + tile_random_drop(x_ptrs, y_ptrs, block_mask, USE_BIAS, bias, rand_mask, p, p_scale, ACTIVATION) - # randomly prune (and scale) the resulting buffer, possibly a no-op - output = _random_prune_and_scale(x, rand_mask, p, p_scale) - - tl.store(y_ptrs, output, mask=block_mask) - - # Update the pointers rows += BLOCK_M # needs to be updated for the mask to be correct x_ptrs += BLOCK_M * stride y_ptrs += BLOCK_M * stride # fmt: off -@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) +@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]}) @triton.autotune( configs=_configs, key=["M", "N", "is_fp16"], @@ -152,7 +167,13 @@ def k_dropout_bw( M, N, p, is_fp16, # autotune - **meta, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SIZE_RAND_BLOCK: tl.constexpr, + TRAINABLE_BIAS: tl.constexpr, + USE_BIAS: tl.constexpr, + ACTIVATION_GRAD: tl.constexpr, ): """ Apply dropout on an input tensor @@ -165,12 +186,6 @@ def k_dropout_bw( """ # fmt: on - BLOCK_M = meta["BLOCK_M"] - BLOCK_N = meta["BLOCK_N"] - SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] - TRAINABLE_BIAS = meta["TRAINABLE_BIAS"] - - rows = tl.arange(0, BLOCK_M) row_id = tl.program_id(axis=0) rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) @@ -190,9 +205,9 @@ def k_dropout_bw( # now go over the tiles grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) col_mask = cols[None, :] < N - p_scale = 1 / (1 - p) if p < 1. else 1. + p_scale = 1 / (1 - p) - if meta["USE_BIAS"]: + if USE_BIAS: b_ptrs = BIAS + cols[None, :] bias = tl.load(b_ptrs, mask=col_mask, other=0.) @@ -211,14 +226,14 @@ def k_dropout_bw( grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.) 
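Note on the decorator changes visible in these dropout kernels: Triton 2.0 `@triton.heuristics` lambdas receive a single `args` dict keyed by argument name instead of the old `*_, **meta` signature, and every value produced by `@triton.heuristics` or `@triton.autotune` has to appear in the kernel signature as a `tl.constexpr`. A minimal, hypothetical kernel (not part of this diff) showing the combination:

```python
import triton
import triton.language as tl


@triton.heuristics({"EVEN_N": lambda args: args["N"] % args["BLOCK_N"] == 0})
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=4),
    ],
    key=["N"],
)
@triton.jit
def k_copy(X, Y, N, BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr):
    # launched as: k_copy[lambda meta: (triton.cdiv(n, meta["BLOCK_N"]),)](x, y, n)
    pid = tl.program_id(0)
    offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    if EVEN_N:
        # no bounds check needed when N divides evenly into blocks
        tl.store(Y + offsets, tl.load(X + offsets))
    else:
        mask = offsets < N
        tl.store(Y + offsets, tl.load(X + offsets, mask=mask, other=0.0), mask=mask)
```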
# optional: fused activation (while the data is in shared memory) - if meta["ACTIVATION_GRAD"]: + if ACTIVATION_GRAD: inputs = tl.load(input_ptrs, mask=block_mask, other=0.) # optionally apply a fused bias - if meta["USE_BIAS"]: + if USE_BIAS: inputs += bias - act_grad = meta["ACTIVATION_GRAD"](inputs).to(grad_out.dtype) + act_grad = ACTIVATION_GRAD(inputs).to(grad_out.dtype) grad_out *= act_grad # randomly prune (and scale) the resulting buffer, possibly a no-op diff --git a/xformers/triton/k_fused_matmul_bw.py b/xformers/triton/k_fused_matmul_bw.py index bfe3229243..b8c8858d1e 100644 --- a/xformers/triton/k_fused_matmul_bw.py +++ b/xformers/triton/k_fused_matmul_bw.py @@ -15,16 +15,16 @@ # fmt: off @triton.heuristics({ - 'EVEN_N': lambda *args, **meta: args[3] % (meta['BLOCK_COL']) == 0, + 'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0, }) @triton.autotune( configs=[ - triton.Config({"BLOCK_COL": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_COL": 64}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_COL": 128}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_COL": 256}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_COL": 512}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_COL": 1024}, num_stages=3, num_warps=16), + triton.Config({"BLOCK_N": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_N": 64}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=8), ], key=["N"], ) @@ -39,7 +39,9 @@ def kernel_bw( # by to get the element one row down (A has M rows) stride_gom, stride_aim, # Meta-parameters - **META, + BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, + ACTIVATION_GRAD: tl.constexpr, ): # fmt: on @@ -47,9 +49,6 @@ def kernel_bw( Go over all the activation inputs, compute the corresponding gradient """ - # extract metaparameters - BLOCK_N = META["BLOCK_COL"] - # this kernel is relatively simple in terms of scheduling: # - per row (pid_m) # - each program a given chunk on the col axis, @@ -62,16 +61,16 @@ def kernel_bw( act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn # compute the gradient which is related to this activation - if META["EVEN_N"]: + if EVEN_N: act_in = tl.load(act_input_ptrs) else: act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0) - grad_act = META["ACTIVATION_GRAD"](act_in) + grad_act = ACTIVATION_GRAD(act_in) # now read the incoming gradient, the backpropagated one is the multiple of both grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn - if META["EVEN_N"]: + if EVEN_N: grad_out = tl.load(grad_out_ptrs) else: grad_out = tl.load(grad_out_ptrs, mask=rn < N) @@ -120,24 +119,16 @@ def fused_matmul_backward( if act_in is None: act_in = grad_out_ - def grid(META): - return ( - M, - triton.cdiv(N, META["BLOCK_COL"]), - ) + grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"])) # noqa # fmt: off kernel_bw[grid]( - # data ptrs - grad_act, grad_out_, act_in, - # shapes - N, - # strides - grad_act.stride(0), act_in.stride(0), - weight.stride(0), weight.stride(1), - # optional fused activation - ACTIVATION_GRAD=activation_grad, + grad_act, grad_out_, act_in, # data ptrs + N, # shapes + grad_act.stride(0), act_in.stride(0), # strides + ACTIVATION_GRAD=activation_grad, # optional fused activation ) + # fmt: on # Backpropagation going up, the reference gradient is now # just before the activation diff 
--git a/xformers/triton/k_fused_matmul_fw.py b/xformers/triton/k_fused_matmul_fw.py index 16fa09f6c6..dbb8c119c0 100644 --- a/xformers/triton/k_fused_matmul_fw.py +++ b/xformers/triton/k_fused_matmul_fw.py @@ -15,24 +15,25 @@ # fmt: off @triton.autotune( configs=[ - triton.Config({"BLOCK_ROW": 16, "BLOCK_COL": 16}, num_stages=5, num_warps=1), - triton.Config({"BLOCK_ROW": 32, "BLOCK_COL": 32}, num_stages=5, num_warps=1), - triton.Config({"BLOCK_ROW": 64, "BLOCK_COL": 32}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_ROW": 32, "BLOCK_COL": 64}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_ROW": 128, "BLOCK_COL": 64}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_ROW": 64, "BLOCK_COL": 128}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_ROW": 128, "BLOCK_COL": 128}, num_stages=4, num_warps=4), - # triton.Config({"BLOCK_ROW": 32, "BLOCK_COL": 256}, num_stages=3, num_warps=4), - # triton.Config({"BLOCK_ROW": 256, "BLOCK_COL": 32}, num_stages=3, num_warps=4), - # triton.Config({"BLOCK_ROW": 64, "BLOCK_COL": 256}, num_stages=3, num_warps=8), - # triton.Config({"BLOCK_ROW": 256, "BLOCK_COL": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 16}, num_stages=5, num_warps=1), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_stages=5, num_warps=1), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=4), + # requires a GPU with enough shared memory + # triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_stages=3, num_warps=4), + # triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_stages=3, num_warps=4), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=3, num_warps=8), + # triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_stages=3, num_warps=8), ], key=["M", "N", "K"], ) @triton.jit def kernel_fma( # Pointers to matrices - OUT, ACT_INPUTS, INPUT, WEIGHT, BIAS, + OUT, ACT_INPUTS, INPUT, WEIGHT, bias, # Matrix dimensions M, N, K, # The stride variables represent how much to increase the ptr by when moving by 1 @@ -41,7 +42,11 @@ def kernel_fma( stride_om, stride_im, stride_wn, stride_wk, # Meta-parameters - **META, + BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BIAS: tl.constexpr, + SAVE_ACT_INPUTS: tl.constexpr, + ACTIVATION: tl.constexpr, ): # fmt: on @@ -59,10 +64,6 @@ def kernel_fma( This kernel will consolidate over K """ - # extract metaparameters - BLOCK_M, GROUP_M = META["BLOCK_ROW"], META["GROUP_ROW"] - BLOCK_N, BLOCK_K = META["BLOCK_COL"], META["BLOCK_K"] - # programs are grouped together to improve L2 hit rate # the logic is that we'll consolidate over K. If the programs were not grouped, # then multiple cols/rows in the result would end up pulling in the same row and lines @@ -98,8 +99,8 @@ def kernel_fma( # initialize and iteratively update accumulator acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if META["BIAS"]: - bias = tl.load(BIAS + rn, mask=rn < N, other=0.0).to(tl.float32) + if BIAS: + bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) acc += bias[None, :] # block level matrix multiplication. 
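`kernel_fma` is launched further down with a 1D grid, one program per `(BLOCK_M x BLOCK_N)` output tile; since the tile sizes are chosen by the autotuner, the grid is a lambda over the meta-parameters. A plain-Python illustration with hypothetical sizes:

```python
import triton  # triton.cdiv is a plain ceil-division helper, usable on the host

M, N = 1024, 768
meta = {"BLOCK_M": 128, "BLOCK_N": 64}  # in practice selected by @triton.autotune

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)  # noqa: E731
print(grid(meta))  # (96,): 8 tiles along M times 12 tiles along N
```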
@@ -114,13 +115,13 @@ def kernel_fma( weight_ptrs += BLOCK_K * stride_wk # optional: save the activation inputs - if META["SAVE_ACT_INPUTS"]: + if SAVE_ACT_INPUTS: act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :] tl.store(act_in_ptrs, acc, mask=(rm[:, None] < M) & (rn[None, :] < N)) # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION"]: - acc = META["ACTIVATION"](acc) + if ACTIVATION: + acc = ACTIVATION(acc) # write back result out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :] @@ -160,28 +161,18 @@ def fused_matmul( act_inputs = torch.empty_like(outputs) if save_act_inputs else x # will not be used in that case # 1D launch kernel where each block gets its own program. - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_ROW"]) * triton.cdiv(N, META["BLOCK_COL"]), - ) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa # fmt: off kernel_fma[grid]( - # data ptrs - outputs, act_inputs, x_, weight, - bias if bias is not None else x, # auto skip bias if not present - # shapes - M, N, K, - # strides - outputs.stride(0), x_.stride(0), + outputs, act_inputs, x_, weight, # data ptrs + bias if bias is not None else x, # auto skip bias if not present + M, N, K, # shapes + outputs.stride(0), x_.stride(0), # strides weight.stride(0), weight.stride(1), - # optional fused activation - ACTIVATION=activation, - # optional fused bias - BIAS=bias is not None, - # speed optimization: group the programs - # improve on data reuse in L2 cache - GROUP_ROW=8, + ACTIVATION=activation, # optional fused activation + BIAS=bias is not None, # optional fused bias + GROUP_M=8, # speed optimization: group the programs BLOCK_K=32, SAVE_ACT_INPUTS=save_act_inputs ) diff --git a/xformers/triton/k_layer_norm.py b/xformers/triton/k_layer_norm.py index 63244475d5..cc1c804445 100644 --- a/xformers/triton/k_layer_norm.py +++ b/xformers/triton/k_layer_norm.py @@ -12,32 +12,9 @@ import triton.language as tl +# fmt: off @triton.jit -def _affine(W, B, N, x, META): - cols = tl.arange(0, META["BLOCK_SIZE_N"]) - - w = tl.load(W + cols, mask=cols < N, other=1.0) - zero = 0.0 - zero = zero.to(w.dtype) # Triton bug workarounds - w = tl.where(cols < N, w, zero) - - b = tl.load(B + cols, mask=cols < N, other=0.0) - b = tl.where(cols < N, b, zero) - y = x * w + b - return y - - -@triton.jit -def _store(y, Y, stride, N, META): - row = tl.program_id(0) - cols = tl.arange(0, META["BLOCK_SIZE_N"]) - - y_ptrs = Y + row * stride + cols - tl.store(y_ptrs, y, mask=cols < N) - - -@triton.jit -def layer_norm_non_affine(X, M, V, stride, N, eps, META): +def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # fmt: on """ Fused layernorm kernel over a 3d tensor. 
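The merged `layer_norm_fw` kernel below computes, per row and in fp32: the mean, the biased variance, the reciprocal standard deviation that is stashed for the backward pass, and the optionally affine output. A PyTorch reference of the same computation, mirroring the parity test with its relaxed `eps = 1e-4` (a hypothetical standalone check, not part of this diff):

```python
import torch

x = torch.randn(8, 512)
w, b, eps = torch.ones(512), torch.zeros(512), 1e-4

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance: sum((x - mean)**2) / N
rstd = 1.0 / torch.sqrt(var + eps)                 # stored per row for the backward pass
y = (x - mean) * rstd * w + b                      # the `affine` branch of the kernel

assert torch.allclose(y, torch.nn.functional.layer_norm(x, (512,), w, b, eps), atol=1e-5)
```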
@@ -48,48 +25,33 @@ def layer_norm_non_affine(X, M, V, stride, N, eps, META):
     """
     row = tl.program_id(0)
-    cols = tl.arange(0, META["BLOCK_SIZE_N"])
+    cols = tl.arange(0, BLOCK_SIZE_N)
+    mask = cols < N

     # Move to this row
     x_ptrs = X + row * stride + cols

-    x = tl.load(x_ptrs, mask=cols < N, other=0.0).to(tl.float32)
-    x = tl.where(cols < N, x, 0.0)  # Triton bug workarounds
-
-    # Compute variance
-    x_mean = tl.sum(x, axis=0) / N
-    x_zm = x - x_mean
-    x_zm = tl.where(cols < N, x_zm, 0.0)  # Triton bug workaround
-    x_var = tl.sum(x_zm * x_zm, axis=0) / N
-    x_inv_sigma = 1.0 / tl.sqrt(x_var + eps)
-
-    # write-back per sample mean/rstd, used in the backward pass
-    tl.store(M + row, x_mean)
-    tl.store(V + row, x_inv_sigma)
-
-    return x_zm * x_inv_sigma
+    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

+    # Compute mean and variance
+    mean = tl.sum(x, axis=0) / N
+    x_zm = tl.where(mask, x - mean, 0.0)
+    tl.store(M + row, mean)

-# fmt: off
-@triton.jit
-def layer_norm_non_affine_fw(X, Y, M, V, stride, N, eps, **META):
-    _store(layer_norm_non_affine(X, M, V, stride, N, eps, META), Y, stride, N, META)
+    x_var = tl.sum(x_zm * x_zm, axis=0) / N
+    rstd = 1.0 / tl.sqrt(x_var + eps)

+    # Normalize, optionally affine
+    y = x_zm * rstd
+    tl.store(V + row, rstd)

-# fmt: off
-@triton.jit
-def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, **META):
-    # fmt: on
-    """
-    Fused layernorm kernel over a 3d tensor.
-    The layer norm is applied over the last dimension.
+    mask = cols < N
+    if affine:
+        w = tl.load(W + cols, mask=mask, other=1.0)
+        b = tl.load(B + cols, mask=mask, other=0.0)
+        y = y * w + b

-    Compute
-        y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta
-    """
-    y = layer_norm_non_affine(X, M, V, stride, N, eps, META)
-    y = _affine(W, B, N, y, META)
-
-    _store(y, Y, stride, N, META)
+    y_ptrs = Y + row * stride + cols
+    tl.store(y_ptrs, y, mask=mask)


 # Backward pass (DX + partial DW + partial DB)
@@ -97,138 +59,121 @@ def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, **META):
 @triton.jit
 def layer_norm_bwd_dx_fused(
     DX, DY, DW, DB,
-    Y, W, B, V,
+    X, W, M, V,
     Lock, stride, N,
-    **META
+    # META-parameters
+    affine: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
 ):
     # fmt: on

-    GROUP_SIZE_M = META["GROUP_SIZE_M"]
-    BLOCK_SIZE_N = META["BLOCK_SIZE_N"]
-
     # position of elements processed by this program
     row = tl.program_id(0)
     cols = tl.arange(0, BLOCK_SIZE_N)
+    mask = cols < N

     # offset data pointers to start at the row of interest
-    y_ptrs = Y + row * stride + cols
+    x_ptrs = X + row * stride + cols
     dy_ptrs = DY + row * stride + cols
-    w_ptrs = W + cols
-    b_ptrs = B + cols
-
-    # offset locks and weight/bias gradient pointer
-    # each kernel instance accumulates partial sums for
-    # DW and DB into one of GROUP_SIZE_M independent buffers
-    # these buffers stay in the L2, which allow this kernel
-    # to be fast
-    lock_id = row % GROUP_SIZE_M
-    Lock += lock_id
-    Count = Lock + GROUP_SIZE_M

     # load data to SRAM
-    y = tl.load(y_ptrs, mask=cols < N, other=0).to(tl.float32)
-    dy = tl.load(dy_ptrs, mask=cols < N, other=0).to(tl.float32)
-    w = tl.load(w_ptrs, mask=cols < N, other=0).to(tl.float32)
-    b = tl.load(b_ptrs, mask=cols < N, other=0).to(tl.float32)
-
+    x = tl.load(x_ptrs, mask=mask, other=0)
+    dy = tl.load(dy_ptrs, mask=mask, other=0)
+    mean = tl.load(M + row)
     rstd = tl.load(V + row)

     # compute dx
-    xhat = (y - b) / w
-    wdy = w * dy
-    xhat = tl.where(cols < N, xhat, 0.0)
-    wdy = tl.where(cols < N, wdy, 0.0)
-    mean1 = tl.sum(xhat * wdy, axis=0) / N
-    mean2 = tl.sum(wdy, axis=0) / N
-    dx = (wdy - (xhat * mean1 + mean2)) * rstd
-
-    # write-back dx
-    _store(dx, DX, stride, N, META)
-
-    # accumulate partial sums for dw/db
-    partial_dw = (dy * xhat).to(w.dtype)
-    partial_db = dy.to(w.dtype)
-
-    # - wait for a lock on the accumulated dw/db
-    while tl.atomic_cas(Lock, 0, 1) == 1:
-        pass
-    count = tl.load(Count)
-
-    # - we got the lock, accumulate this kernel's results with
-    # the stored values.
-    dw_ptrs = DW + lock_id * N + cols
-    db_ptrs = DB + lock_id * N + cols
+    xhat = (x - mean) * rstd

-    if count == 0:
-        # first store doesn't accumulate
-        tl.atomic_xchg(Count, 1)
+    if affine:
+        w = tl.load(W + cols, mask=mask, other=0)
+        wdy = w * dy
     else:
-        partial_dw += tl.load(dw_ptrs, mask=cols < N, other=0.)
-        partial_db += tl.load(db_ptrs, mask=cols < N, other=0.)
-
-    tl.store(dw_ptrs, partial_dw, mask=cols < N)
-    tl.store(db_ptrs, partial_db, mask=cols < N)
-
-    # release lock
-    tl.atomic_xchg(Lock, 0)
-
-
-@triton.jit
-def layer_norm_no_affine_bwd(
-    DX, DY,
-    Y, V,
-    stride, N,
-    **META
-):
-    # fmt: on
-
-    # position of elements processed by this program
-    row = tl.program_id(0)
-    cols = tl.arange(0, META["BLOCK_SIZE_N"])
+        wdy = dy

-    # offset data pointers to start at the row of interest
-    y_ptrs = Y + row * stride + cols
-    dy_ptrs = DY + row * stride + cols
-
-    # load data to SRAM
-    y = tl.load(y_ptrs, mask=cols < N, other=0).to(tl.float32)
-    dy = tl.load(dy_ptrs, mask=cols < N, other=0).to(tl.float32)
-
-    rstd = tl.load(V + row)
-
-    # compute dx
-    xhat = tl.where(cols < N, y, 0.0)
-    wdy = tl.where(cols < N, dy, 0.0)
+    xhat = tl.where(mask, xhat, 0.)
+    wdy = tl.where(mask, wdy, 0.)
     mean1 = tl.sum(xhat * wdy, axis=0) / N
     mean2 = tl.sum(wdy, axis=0) / N
     dx = (wdy - (xhat * mean1 + mean2)) * rstd

     # write-back dx
-    _store(dx, DX, stride, N, META)
+    cols = tl.arange(0, BLOCK_SIZE_N)
+    mask = cols < N  # re-materialize the mask to save registers
+    dx_ptrs = DX + row * stride + cols
+    tl.store(dx_ptrs, dx, mask=mask)
+
+    if affine:
+        # accumulate partial sums for dw/db
+        partial_dw = (dy * xhat).to(w.dtype)
+        partial_db = dy.to(w.dtype)
+
+        # offset locks and weight/bias gradient pointer
+        # each kernel instance accumulates partial sums for
+        # DW and DB into one of GROUP_SIZE_M independent buffers
+        # these buffers stay in the L2, which allow this kernel
+        # to be fast
+        lock_id = row % GROUP_SIZE_M
+        Lock += lock_id
+        Count = Lock + GROUP_SIZE_M
+
+        # - wait for a lock on the accumulated dw/db
+        while tl.atomic_cas(Lock, 0, 1) == 1:
+            pass
+        count = tl.load(Count)
+
+        # - we got the lock, accumulate this kernel's results with
+        # the stored values.
+        dw_ptrs = DW + lock_id * N + cols
+        db_ptrs = DB + lock_id * N + cols
+
+        if count == 0:
+            # first store doesn't accumulate
+            tl.atomic_xchg(Count, 1)
+        else:
+            partial_dw += tl.load(dw_ptrs, mask=mask, other=0.)
+            partial_db += tl.load(db_ptrs, mask=mask, other=0.)
+
+        tl.store(dw_ptrs, partial_dw, mask=mask)
+        tl.store(db_ptrs, partial_db, mask=mask)
+
+        # release lock
+        tl.atomic_xchg(Lock, 0)


 # Backward pass (total DW + total DB)
 # fmt: off
 @triton.jit
-def layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
+def layer_norm_bwd_dwdb(
+    DW, DB, FINAL_DW, FINAL_DB,
+    M, N,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr
+):
     # fmt: on
+
     pid = tl.program_id(0)

-    BLOCK_SIZE_M = meta["BLOCK_SIZE_M"]
-    BLOCK_SIZE_N = meta["BLOCK_SIZE_N"]
     cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    mask_cols = cols < N
+
     dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

     for i in range(0, M, BLOCK_SIZE_M):
-        rows = i + tl.arange(0, meta["BLOCK_SIZE_M"])
+        rows = i + tl.arange(0, BLOCK_SIZE_M)
         offs = rows[:, None] * N + cols[None, :]
+        mask_rm = rows < M

-        dw += tl.load(DW + offs, mask=(rows[:, None] < M) & (cols[None, :] < N), other=0.0)
-        db += tl.load(DB + offs, mask=(rows[:, None] < M) & (cols[None, :] < N), other=0.0)
+        dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
+        db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)

     sum_dw = tl.sum(dw, axis=0)
     sum_db = tl.sum(db, axis=0)

-    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
-    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
+    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    mask_cols = cols < N
+
+    tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
+    tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)
diff --git a/xformers/triton/k_softmax.py b/xformers/triton/k_softmax.py
index 67c28dd1d0..09e35c8034 100644
--- a/xformers/triton/k_softmax.py
+++ b/xformers/triton/k_softmax.py
@@ -12,8 +12,8 @@
 # and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html


-def get_depth(*args, **_):
-    return triton.next_power_of_2(args[-1])
+def get_depth(args):
+    return triton.next_power_of_2(args["K"])


 # autotune: Triton will test out these configurations, and automatically pick the fastest one.
@@ -30,7 +30,7 @@ def get_depth(*args, **_):
     ],
     key=["K"],
 )
-@triton.heuristics(values={"depth": get_depth , "is_fp16": lambda *args, **_: args[0].dtype == torch.float16})
+@triton.heuristics(values={"depth": get_depth , "is_fp16": lambda args: args["Y"].dtype == torch.float16})
 @triton.jit
 def _softmax(
     Y, X, M,
@@ -38,7 +38,12 @@ def _softmax(
     stride_xm, stride_xn,
     stride_mn,
     K,
-    **meta,  # extra parameters which can be automatically filled in given some heuristics
+    # Meta-params
+    depth: tl.constexpr,
+    causal: tl.constexpr,
+    use_mask: tl.constexpr,
+    is_fp16: tl.constexpr,
+    log: tl.constexpr,
 ):
     # fmt: om

@@ -54,7 +59,7 @@ def _softmax(
     n = tl.program_id(1)

     # col indices
-    k = tl.arange(0, meta["depth"])
+    k = tl.arange(0, depth)

     # the memory address of all the elements that we want to load can be computed as follows
     x_ptrs = X + m * stride_xm + n * stride_xn + k
@@ -63,18 +68,18 @@ def _softmax(
     io_mask = k < K

     # Causal - 1: skip on the loads directly
-    if meta["causal"]:
+    if causal:
         io_mask = io_mask & (k <= n)

     x = tl.load(x_ptrs, mask=io_mask, other=float("-inf"))

     # Causal - 2: enforce correctness over a couple of misloaded values
-    if meta["causal"]:
+    if causal:
         off = float("-inf")
-        off = off.to(x.dtype)
+        off = off.to(x.dtype)  # type: ignore
         x = tl.where(k > n, off, x)

-    if meta["use_mask"]:
+    if use_mask:
         mask_ptrs = M + n * stride_mn + k
         add_mask = tl.load(mask_ptrs, io_mask, other=float("-inf"))
         x += add_mask
@@ -82,7 +87,7 @@ def _softmax(
     # compute numerically-stable softmax
     z = x - tl.max(x, axis=0)

-    if meta["is_fp16"]:
+    if is_fp16:
         # tl.exp() crashes on fp16 values
         # See https://github.com/openai/triton/issues/241
         z = z.to(tl.float32)
@@ -90,7 +95,7 @@ def _softmax(
     num = tl.exp(z)
     denom = tl.sum(num, axis=0)

-    if meta["log"]:
+    if log:
         y = z - tl.log(denom)
     else:
         y = num / denom
@@ -115,7 +120,7 @@ def _softmax(
     ],
     key=["K"],
 )
-@triton.heuristics(values={"is_fp16": lambda *args, **_: args[0].dtype == torch.float16})
+@triton.heuristics(values={"is_fp16": lambda args: args["GradIn"].dtype == torch.float16})
 @triton.jit
 def _softmax_backward(
     GradIn, GradOut, Out,
@@ -123,7 +128,11 @@ def _softmax_backward(
     stride_gm, stride_gn,
     stride_om, stride_on,
     K,
-    **meta,
+    # meta-params
+    depth: tl.constexpr,
+    causal: tl.constexpr,
+    is_fp16: tl.constexpr,
+    log: tl.constexpr,
 ):
     # fmt: on

@@ -136,7 +145,7 @@ def _softmax_backward(
     n = tl.program_id(1)

     # col indices
-    k = tl.arange(0, meta["depth"])
+    k = tl.arange(0, depth)

     # the memory address of all the elements that we want to load can be computed as follows
     grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k
@@ -146,22 +155,22 @@ def _softmax_backward(
     io_mask = k < K

     # Causal - 1: skip on the loads directly
-    if meta["causal"]:
+    if causal:
         io_mask = io_mask & (k <= n)

     g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0))
     o = tl.load(out_ptrs, mask=io_mask, other=float(0))

     # Causal - 2: enforce correctness over a couple of misloaded values
-    if meta["causal"]:
+    if causal:
         zero = float(0)
-        zero = zero.to(g.dtype)
+        zero = zero.to(g.dtype)  # type: ignore
         g = tl.where(k > n, zero, g)
         o = tl.where(k > n, zero, o)

-    if meta["log"]:
+    if log:
         s = tl.sum(g, 0)
-        if meta["is_fp16"]:
+        if is_fp16:
             o = o.to(tl.float32)
         grad_in = g - tl.exp(o) * s
     else:
diff --git a/xformers/triton/k_sum.py b/xformers/triton/k_sum.py
index 4d3d6ebb26..339c368810 100644
--- a/xformers/triton/k_sum.py
+++ b/xformers/triton/k_sum.py
@@ -14,7 +14,9 @@ def k_sum_0(
     stride_xm,
     M, N,
     is_fp16,
-    **meta,
+    # META-params
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
 ):
     # fmt: om

@@ -22,8 +24,6 @@ def k_sum_0(
     Sum a 2d tensor over the first (strided) dimension.
     This extracts some speed through a parallel sum across the second dimension
     """
-    BLOCK_M = meta["BLOCK_M"]
-    BLOCK_N = meta["BLOCK_N"]

     # partial row indices. We'll reduce over this dimension
     m = tl.arange(0, BLOCK_M)
diff --git a/xformers/triton/layer_norm.py b/xformers/triton/layer_norm.py
index 4866561063..520262997b 100644
--- a/xformers/triton/layer_norm.py
+++ b/xformers/triton/layer_norm.py
@@ -18,8 +18,6 @@
     layer_norm_bwd_dwdb,
     layer_norm_bwd_dx_fused,
     layer_norm_fw,
-    layer_norm_no_affine_bwd,
-    layer_norm_non_affine_fw,
 )

 _triton_layernorm_fp16_enabled = False  # NOTE: PyTorch keeps layernorm as fp32
@@ -30,6 +28,10 @@ class _LayerNorm(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
     def forward(ctx, x, weight, bias, eps):
+        # catch eps being too small if the tensors are fp16
+        if x.dtype == torch.float16:
+            eps = max(eps, 1.6e-5)
+
         # allocate output
         y = torch.empty_like(x)

@@ -61,35 +63,24 @@ def forward(ctx, x, weight, bias, eps):
             y = y.contiguous()

         # heuristics for number of warps.
-        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 8)
+        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

         # enqueue kernel
         # fmt: off
-        if weight is None:
-            layer_norm_non_affine_fw[(M,)](
-                x_arg, y, mean, rstd,
-                x_arg.stride(0),
-                N,
-                eps,
-                num_warps=num_warps,
-                BLOCK_SIZE_N=BLOCK_SIZE_N
-            )
-        else:
-            layer_norm_fw[(M,)](
-                x_arg, y, weight, bias, mean, rstd,
-                x_arg.stride(0),
-                N,
-                eps,
-                num_warps=num_warps,
-                BLOCK_SIZE_N=BLOCK_SIZE_N
-            )
+        layer_norm_fw[(M,)](
+            x_arg, y, weight, bias, mean, rstd,
+            x_arg.stride(0),
+            N,
+            eps,
+            num_warps=num_warps,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            affine=weight is not None
+        )
         # fmt: on

-        ctx.save_for_backward(y, rstd, weight, bias)
+        ctx.save_for_backward(x, mean, rstd, weight)
         ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
         ctx.num_warps = num_warps
-        ctx.eps = eps
-        ctx.N = N

         return y.reshape_as(x)

@@ -98,59 +89,60 @@ def forward(ctx, x, weight, bias, eps):
     def backward(
         ctx, dy
     ):  # pragma: no cover  # this is covered, but called directly from C++
-        y, var, weight, bias = ctx.saved_tensors
+        x, mean, rstd, weight = ctx.saved_tensors
+
+        # flatten the batch dimension, if any.
+        # We're interested in 'samples' x norm_dimension
+        x = x.reshape(-1, x.size(-1))
+        M, N = x.size()

         # heuristics for amount of parallel reduction stream for DG/DB
-        N = y.size(-1)
-        GROUP_SIZE_M = 64
+        GROUP_SIZE_M = 32
         if N <= 8192:
-            GROUP_SIZE_M = 96
+            GROUP_SIZE_M = 64
         if N <= 4096:
+            GROUP_SIZE_M = 96
+        if N <= 2048:
             GROUP_SIZE_M = 128
         if N <= 1024:
             GROUP_SIZE_M = 256

-        # flatten the batch dimension, if any.
-        # We're interested in 'samples' x norm_dimension
-        y = y.reshape(-1, y.size(-1))
-        M, N = y.size()
+        if dy.dtype == torch.float32:
+            GROUP_SIZE_M = GROUP_SIZE_M // 2

         # allocate output
         locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda")
-        t_args = {"dtype": y.dtype, "device": y.device}
-        _dw = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
-        _db = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
-        dw = torch.empty((y.size(-1),), **t_args)
-        db = torch.empty((y.size(-1),), **t_args)
+        t_args = {"dtype": x.dtype, "device": x.device}
+        _dw = torch.empty((GROUP_SIZE_M, x.size(-1)), **t_args)
+        _db = torch.empty_like(_dw)
+        dw = torch.empty((x.size(-1),), **t_args)
+        db = torch.empty_like(dw)
         dy = dy.contiguous()
         dx = torch.empty_like(dy)

         # Check the tensor shapes and layouts
         # we suppose in the kernel that they have the same size and are contiguous
         assert (
-            dx.numel() == y.numel()
+            dy.numel() == x.numel()
         ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

         # enqueue kernel using forward pass heuristics
         # also compute partial sums for DW and DB
+        num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16)
         # fmt: off
-        meta = {"BLOCK_SIZE_N": ctx.BLOCK_SIZE_N,
-                "GROUP_SIZE_M": GROUP_SIZE_M,
-                "num_warps": ctx.num_warps}
-        if weight is None:
-            layer_norm_no_affine_bwd[(M,)](dx, dy, y, var, y.stride(0), N, **meta)
-            return dx, None, None, None
-
         layer_norm_bwd_dx_fused[(M,)](
-            dx, dy, _dw, _db,
-            y, weight, bias, var,
+            dx, dy, _dw, _db, x,
+            weight if weight is not None else x,
+            mean, rstd,
             locks,
-            y.stride(0),
+            x.stride(0),
             N,
-            **meta
+            affine=weight is not None,
+            GROUP_SIZE_M=GROUP_SIZE_M,
+            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
+            num_warps=num_warps
         )
-        # fmt: on

         def grid(meta):
@@ -163,7 +155,7 @@ def grid(meta):
             GROUP_SIZE_M,
             N,
             BLOCK_SIZE_M=32,
-            BLOCK_SIZE_N=128
+            BLOCK_SIZE_N=64
         )
         # fmt: on

diff --git a/xformers/triton/softmax.py b/xformers/triton/softmax.py
index 8c1d4afdba..753f1ddfd6 100644
--- a/xformers/triton/softmax.py
+++ b/xformers/triton/softmax.py
@@ -5,7 +5,6 @@


 import logging
-from enum import Enum
 from typing import Optional

 import torch
@@ -22,11 +21,6 @@

 _triton_registered_warnings = False


-class MaskType(str, Enum):
-    ADD = "add"
-    MUL = "mul"
-
-
 # Helper to handle the SPMD launch grid and error cases
 class _softmax_triton(torch.autograd.Function):
     @staticmethod
diff --git a/xformers/utils.py b/xformers/utils.py
index f2db7e808e..7a7dd44bae 100644
--- a/xformers/utils.py
+++ b/xformers/utils.py
@@ -91,7 +91,7 @@ def rmf(filename: str) -> None:

 @contextlib.contextmanager
 def temp_files_ctx(num: int) -> Generator:
-    """ A context to get tempfiles and ensure they are cleaned up. """
+    """A context to get tempfiles and ensure they are cleaned up."""
     files = [tempfile.mkstemp()[1] for _ in range(num)]

     yield tuple(files)