diff --git a/BENCHMARKS.md b/BENCHMARKS.md index 56c21a3e0d..9bd0cecb8f 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -58,7 +58,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/ ![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png) -![Fused linear layers throughput in fp16 - training](docs/plots/fused_linea/FusedLinear_fp16_FW_BW_gelu.png) +![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png) ![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png) @@ -74,7 +74,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/ ![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_none.png) -![Fused linear layers throughput in fp16 - training](docs/plots/fused_line/FusedLinear_fp16_FW_BW_none.png) +![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png) ### Fused layer norm @@ -89,18 +89,25 @@ Note that in the Triton case the slowdowns at extreme sizes are because of regis ![Fused layer norm throughput in fp32 - training](docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png)) -### Fused dropout + bias +### Fused dropout + bias + activation -You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a nVidia V100, Triton 1.1 and PyTorch 1.10. +You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for an nVidia 3080 laptop GPU, Triton 1.1 and PyTorch 1.10. -![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png) +![Fused dropout + bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png) -![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png)) +![Fused dropout + bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png) -![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png)) +![Fused dropout + bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png) -![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png)) +![Fused dropout + bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png) +![Fused dropout + bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png) + +![Fused dropout + bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png) + +![Fused dropout + bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png) + +![Fused dropout + bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png) ## LRA diff --git a/CHANGELOG.md b/CHANGELOG.md index 66c2cf20cd..e82850914c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be
documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## TBD +### Fixed +- Much faster fused dropout [#164] + ## [0.0.7] - 2021-11-30 ### Fixed - Dropout setting not properly passed in many attentions [#123] diff --git a/README.md b/README.md index 011ece45be..32b991b687 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)* 5. Hackable 1. Not using monolithic CUDA kernels, composable building blocks 2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible + 3. Native support for SquaredReLU (on top of ReLU, LeakyReLU, GeLU, ..), extensible activations ### FAQ ? diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16.png deleted file mode 100644 index 77a4c8172d..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png index fb03c2659b..b6daad60b1 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png index eb252306a4..a6ce46bd11 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png new file mode 100644 index 0000000000..53551bfb50 Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32.png deleted file mode 100644 index 9b6c9fa613..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png index 6108736d6a..c38a9920ee 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png index f5506ef73c..3fab1aa5c9 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png new 
file mode 100644 index 0000000000..87aaccf6cb Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16.png deleted file mode 100644 index d6742b7033..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png index 5ba3897ed3..77a3ca064a 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png index f98c6c0f4f..c9aa6fe7f6 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png new file mode 100644 index 0000000000..f0ebe52763 Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32.png deleted file mode 100644 index e204602861..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png index 10072eb2fe..55b709cc3c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png index e2f11bdb1c..6ff747c84e 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png new file mode 100644 index 0000000000..18ddbe839d Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png deleted file mode 100644 index d1f4d04866..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png index e2218061c6..c180cd479c 100644 Binary files 
a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png index 0f3e342c34..6cd37c7822 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png new file mode 100644 index 0000000000..baa5b1a640 Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png deleted file mode 100644 index 6d9bfebb2f..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png index f922504d21..3764396b94 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png index dbddfd81c0..800e2f4a28 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png new file mode 100644 index 0000000000..5017cf7266 Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png deleted file mode 100644 index c7542dc2d4..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png index 61f526809c..f9b9ad27ef 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png index 4749dbb2c4..d4e4e06d38 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png 
b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png new file mode 100644 index 0000000000..169e0304aa Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png deleted file mode 100644 index b1af20e791..0000000000 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png and /dev/null differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png index bcb2e585b0..c161c0e121 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png index eec12679dd..ed60dd3363 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png new file mode 100644 index 0000000000..65522a30e1 Binary files /dev/null and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png differ diff --git a/examples/microGPT.py b/examples/microGPT.py index abc5370a6a..fe664d1c1c 100644 --- a/examples/microGPT.py +++ b/examples/microGPT.py @@ -311,9 +311,9 @@ def top_k_logits(logits, k): gpus=1, max_epochs=EPOCHS, precision=16, - gradient_clip_val=1, + gradient_clip_val=1, # Use to catch divergent gradients, if experimenting log_every_n_steps=1, - detect_anomaly=True, + # detect_anomaly=True, # Use to catch NaNs, if experimenting accumulate_grad_batches=REF_BATCH // BATCH, ) diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 74e0bc73b2..9068a8e969 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -44,6 +44,15 @@ def test_dropout_cpu(): x = torch.normal(0, 1, size=(16, 16), device="cpu") _ = triton_dropout(x) + # Check eval means no dropout + triton_dropout.eval() + y = triton_dropout(x) + assert y.count_nonzero() == y.numel() + + triton_dropout.train() + y = triton_dropout(x) + assert y.count_nonzero() != y.numel() + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( @@ -53,7 +62,8 @@ def test_dropout_cpu(): @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("amp", [False, True]) @pytest.mark.parametrize("bias", [False, True]) -def test_dropout(shape, amp, bias): +@pytest.mark.parametrize("p", [0, 0.1, 0.5]) +def test_dropout(shape, amp, bias, p): """ Check some basic dropout properties """ @@ -97,6 +107,11 @@ def test_dropout(shape, amp, bias): == y.shape[1] ) + # Check that the drop probability is about right + y = triton_dropout(x, p=p) + drop_p = (y.numel() - y.count_nonzero()) / y.numel() + assert abs(drop_p - p) < 0.01 + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( @@ -107,7 +122,7 @@ def test_dropout(shape, amp, bias): @pytest.mark.parametrize("amp", [False, True]) 
@pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("activation", [a.value for a in Activation]) -@pytest.mark.parametrize("p", [0, 0.001, 0.5]) +@pytest.mark.parametrize("p", [0, 0.01, 0.5]) def test_dropout_parity(shape, amp, bias, activation, p): """ Check some basic dropout properties @@ -158,4 +173,4 @@ def test_dropout_parity(shape, amp, bias, activation, p): if bias: assert torch.allclose( torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01 - ), f"{b.grad.norm()}\n{b_.grad.norm()}" + ), f"{b.grad.norm()} - {b_.grad.norm()}" diff --git a/xformers/benchmarks/benchmark_triton_dropout.py b/xformers/benchmarks/benchmark_triton_dropout.py index 8aa05a33c1..e7e3bf738b 100644 --- a/xformers/benchmarks/benchmark_triton_dropout.py +++ b/xformers/benchmarks/benchmark_triton_dropout.py @@ -62,12 +62,14 @@ def torch_step(x): y = torch_act(y) if backward: + y.grad = None torch.norm(y).backward() return y def triton_step(x): y = triton_dropout(x) if backward: + y.grad = None torch.norm(y).backward() return y @@ -85,7 +87,9 @@ def triton_step(x): ), ), ]: - time = triton.testing.do_bench(lambda: testcase.function(a))[0] + time = triton.testing.do_bench( + lambda: testcase.function(a), grad_to_none=[a, b] + )[0] key = f"B={B}, M={M}, K={K}" if key not in results: results[key] = {} @@ -105,7 +109,7 @@ def triton_step(x): ) -for activation in [Activation.GeLU, None]: +for activation in [Activation.GeLU, None, Activation.SquaredReLU]: for bw in [True, False]: for bias in [True, False]: bench_dropout(bias, bw, activation) diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index d8ab2e3860..365d010fa1 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -28,12 +28,12 @@ def pretty_print(results, title, units): """ Printout the contents of a dict as a human-readable and Markdown compatible array""" print(title) - header = " Units: {:<40}".format(units) - print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) + header = " Units: {:<45}".format(units) + print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) offset = len(header) print( - "|{}|".format("-" * offset) + "|-{}|".format("-" * offset) + "".join("{}|".format("-" * 20) for _ in results.keys()) ) @@ -44,7 +44,7 @@ def pretty_print(results, title, units): for k, w in workloads.items(): print( - "|{0:<{offset}}|".format(k, offset=offset) + "| {0:<{offset}}|".format(k, offset=offset) + "".join("{:<20}|".format(v) for v in w) ) @@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""): plt.xticks(rotation=45) plt.savefig(filename, bbox_inches="tight") - plt.clf() + plt.close(f) if _triton_is_available: diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index ec4cb09bd5..6e2ded2316 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -19,41 +19,52 @@ get_triton_activation_kernel, ) from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw -from xformers.triton.sum_strided import sum_2d_dim_0 + +GROUP_M = 32 +BLOCK_M = GROUP_M // 4 +BLOCK_N = 128 # Helper to handle the SPMD launch grid and error cases class _dropout(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, x, p, bias, activation, activation_grad): + def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias): # Soft-flatten an hypothetical 3rd dimension x_ = x.reshape(-1, x.shape[-1]).contiguous() y = torch.empty_like(x_) - _, N = 
x_.shape - - assert bias is None or bias.dtype == x.dtype, bias + M, N = x_.shape - # Generate one seed per sample - # seed max is int32 max for positive numbers: 2**16 - seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32) + assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N) - # SPMD launch grid def grid(meta): + # NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for + # a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks + # but we need to take this factor of 4 into account when scheduling all the kernels return ( - x_.shape[0], - triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]), + triton.cdiv(M, meta["BLOCK_M"] * 4), + triton.cdiv(N, meta["BLOCK_N"]), ) + N_BLOCK_N = triton.cdiv(N, BLOCK_N) + + # Generate one seed per sample + # seed max is int32 max for positive numbers: 2**16 + seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32) + # fmt: off k_dropout_fw[grid]( - y, x_, bias if bias is not None else x_, + y, x_, + bias if bias is not None else x_, seeds, y.stride(0), - N, + M, N, p, + x.dtype == torch.float16, USE_BIAS=bias is not None, - ACTIVATION=activation + ACTIVATION=activation, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, ) # fmt: on @@ -61,7 +72,8 @@ def grid(meta): ctx.save_for_backward(seeds, bias, x) else: ctx.save_for_backward(seeds, bias, None) - ctx.trainable_bias = bias is not None + + ctx.trainable_bias = bias is not None and trainable_bias ctx.activation_grad = activation_grad ctx.p = p @@ -76,7 +88,7 @@ def backward(ctx, grad_out): grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous() grad_in = torch.empty_like(grad_out_) - _, N = grad_out_.shape + M, N = grad_out_.shape # Optional inputs to compute the activation contribution to the gradient assert inputs is not None or ctx.activation_grad is None @@ -84,32 +96,63 @@ def backward(ctx, grad_out): if inputs is None: inputs = grad_out_ elif inputs.ndim > 2: - inputs = inputs.reshape(-1, grad_out.shape[-1]) + inputs = inputs.reshape(-1, N) + + # We split the problem in tiles: + # - over M there will be a follow up reduction + # - over M, we go by 4 tiles at at time (consequence of the random number generation) + # - over N we compromise in between trying to use as much memory paralellism as possible, + # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too + # big because of register spilling + N_BLOCKS_M = triton.cdiv(M, GROUP_M) + + if ctx.trainable_bias: + grad_bias = torch.empty( + ( + N_BLOCKS_M, + N, + ), + device=grad_in.device, + dtype=grad_in.dtype, + ) + + else: + grad_bias = grad_in # will not be used - # SPMD launch grid def grid(meta): + # NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for + # a given seed and offsets. 
"BLOCK_M" here describes the size of one of these blocks + # but we need to take this factor of 4 into account when scheduling all the kernels return ( - grad_out_.shape[0], - triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]), + triton.cdiv(M, meta["BLOCK_M"] * 4), + triton.cdiv(N, meta["BLOCK_N"]), ) # fmt: off k_dropout_bw[grid]( - grad_in, grad_out_, inputs, bias if bias is not None else inputs, + grad_in, grad_bias, grad_out_, + inputs, bias if bias is not None else inputs, seeds, grad_out_.stride(0), inputs.stride(0), - N, + M, N, ctx.p, + grad_in.dtype == torch.float16, USE_BIAS=bias is not None, - ACTIVATION_GRAD=ctx.activation_grad) + ACTIVATION_GRAD=ctx.activation_grad, + TRAINABLE_BIAS=ctx.trainable_bias, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) # fmt: on - if ctx.trainable_bias: - grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in) - else: - grad_bias = None - - return grad_in.reshape_as(grad_out), None, grad_bias, None, None + return ( + grad_in.reshape_as(grad_out), + None, + torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None, + None, + None, + None, + ) def dropout( @@ -129,7 +172,14 @@ def dropout( act_kernel = get_triton_activation_kernel(activation) act_grad_kernel = get_triton_activation_bwd_kernel(activation) - return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel) + return _dropout.apply( + x, + p, + bias, + act_kernel, + act_grad_kernel, + bias is not None and bias.requires_grad, + ) class FusedDropoutBias(torch.nn.Module): @@ -142,10 +192,13 @@ def __init__( super().__init__() self.p = p self.activation_type = activation - self.register_buffer( - "bias", torch.zeros(bias_shape) if bias_shape is not None else None + self.bias = ( + torch.zeros(bias_shape, requires_grad=True) + if bias_shape is not None + else None ) self.activation = get_triton_activation_kernel(activation) + self.pytorch_activation = build_activation(self.activation_type) self.activation_grad = get_triton_activation_bwd_kernel(activation) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -153,12 +206,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.bias is not None: # type: ignore self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore + # Train/inference + p = self.p if self.training else 0.0 + + # This kernel is slower than pytorch for small buffers, bypassing it in that case + perf_check = x.shape[-1] > 512 + # Catch a non-cuda setup, fallback to pytorch - if not x.is_cuda: - activation = build_activation(self.activation_type) + if not x.is_cuda or not perf_check: x = x + self.bias if self.bias is not None else x - x = activation(x) - return torch.nn.functional.dropout(x, self.p) + x = self.pytorch_activation(x) + return torch.nn.functional.dropout(x, p) - p = self.p if self.training else 0.0 - return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad) + # The normal, Triton-backed path + return _dropout.apply( + x, p, self.bias, self.activation, self.activation_grad, True + ) diff --git a/xformers/triton/k_activations.py b/xformers/triton/k_activations.py index 0964096d6c..31049101c9 100644 --- a/xformers/triton/k_activations.py +++ b/xformers/triton/k_activations.py @@ -64,8 +64,7 @@ def relu(x): .. 
_ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html """ zero = 0.0 - zero = zero.to(x.dtype) - return tl.where(x >= 0, x, zero) + return tl.where(x >= 0, x, zero.to(x.dtype)) @triton.jit @@ -74,10 +73,8 @@ def relu_grad(x): # in that it does not require the input to retrospectively compute its gradient # here the input is the downstream gradient, and we return the upstream gradient directly zero = 0.0 - zero = zero.to(x.dtype) one = 1.0 - one = one.to(x.dtype) - return tl.where(x >= 0, one, zero) + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) @triton.jit @@ -88,7 +85,7 @@ def squared_relu(x): .. _Primer: https://arxiv.org/abs/2109.08668 """ x_ = relu(x) - return x_ * x_ + return (x_ * x_).to(x.dtype) @triton.jit diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index 61878840f9..2af660ffd9 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -10,138 +10,234 @@ import triton import triton.language as tl -_k_configs = [ - triton.Config({"BLOCK_SIZE": 128}, num_warps=1), - triton.Config({"BLOCK_SIZE": 512}, num_warps=2), - triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), - triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), - triton.Config({"BLOCK_SIZE": 4096}, num_warps=16), +_configs = [ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), ] @triton.jit -def _drop_and_scale(SEEDS, row, p, offsets, x): - # randomly prune the weights - seed = SEEDS + row - random = tl.rand(seed.to(tl.int32), offsets) - x_keep = random > p +def _get_4_bin_masks(seed, rand_offsets, p): + rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets) + # binarize masks, save registers + # NOTE: We keep the random numbers as is there (integers over int32), + # and convert the threshold instead, for speed + + # The initial distribution is -2**31 / 2**31 -1 + # and our float threshold is in between [0, 1] + # The full computation is: `start_point + full range * p` + threshold = (-2147483648.0 + 4294967295.0 * p).to(tl.int32) + rand_mask1 = rand1 > threshold + rand_mask2 = rand2 > threshold + rand_mask3 = rand3 > threshold + rand_mask4 = rand4 > threshold + + return rand_mask1, rand_mask2, rand_mask3, rand_mask4 + + +@triton.jit +def _random_prune_and_scale(x, rand_mask, p, p_scale): zero = 0.0 - zero = zero.to(x.dtype) - # prune and normalize in one go - return tl.where(x_keep, (x / (1 - p)).to(x.dtype), zero) + if p > 0.0: + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, x.shape) + + # prune and normalize in one go + x = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype)) + + return x # fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=_configs, + key=["M", "N", "is_fp16"], ) @triton.jit def k_dropout_fw( Y, X, BIAS, SEEDS, stride, - N, + M, N, p, - **META, + is_fp16, # autotune + **meta, ): """ Apply dropout on an input tensor - Y : Output (M, N) - X : Input (M, N) - S : Seeds (M,) + Y : Output (M, N) + X : Input (M, N) + BIAS (N,) + SEEDS (M,) p : dropout probability """ # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] - # compute memory offsets of elements handled by this 
instance - offsets = row * stride + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) - # load data from x - x_ptrs = X + offsets - x = tl.load(x_ptrs, mask=mask) + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - x += b + # pointers starting point + x_ptrs = X + rows[:, None] * stride + cols[None, :] + y_ptrs = Y + rows[:, None] * stride + cols[None, :] - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION"]: - x = META["ACTIVATION"](x) + # go over all the tiles, one by one + rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4 + rand_mask1, rand_mask2, rand_mask3, rand_mask4 = _get_4_bin_masks(seed, rand_offsets, p) - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, offsets, x) - else: - output = x + col_mask = cols[None, :] < N + p_scale = 1 / (1 - p) if p < 1. else 1. - y_ptrs = Y + offsets - tl.store(y_ptrs, output, mask=mask) + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.) + + for i in range(4): + # cycle through the binary masks (workaround / no indexing) + if i == 0: + rand_mask = rand_mask1 + elif i == 1: + rand_mask = rand_mask2 + elif i == 2: + rand_mask = rand_mask3 + else: + rand_mask = rand_mask4 + + block_mask = (rows[:, None] < M) & col_mask + x = tl.load(x_ptrs, mask=block_mask, other=0.) + + # optionally apply a fused bias + if meta["USE_BIAS"]: + x += bias + + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION"]: + x = meta["ACTIVATION"](x) + + # randomly prune (and scale) the resulting buffer, possibly a no-op + output = _random_prune_and_scale(x, rand_mask, p, p_scale) + + tl.store(y_ptrs, output, mask=block_mask) + + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + x_ptrs += BLOCK_M * stride + y_ptrs += BLOCK_M * stride # fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=_configs, + key=["M", "N", "is_fp16"], ) @triton.jit def k_dropout_bw( - GRAD_IN, GRAD_OUT, INPUTS, BIAS, SEEDS, + GRAD_IN, GRAD_BIAS, GRAD_OUT, + INPUTS, BIAS, SEEDS, stride_grad, stride_inputs, - N, + M, N, p, - **META, + is_fp16, # autotune + **meta, ): """ Apply dropout on an input tensor GRAD_OUT (M, N) + GRAD_BIAS (N,) GRAD_IN (M, N) BIAS (N,) - SEEDS (M,) + SEEDS (N,) p : dropout probability """ # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) - - # compute memory offsets of elements handled by this instance - grad_offsets = row * stride_grad + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N - - # load data from x - grad_out_ptrs = GRAD_OUT + grad_offsets - grad_out = tl.load(grad_out_ptrs, mask=mask) - - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION_GRAD"]: - input_ptrs = INPUTS + row * stride_inputs + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - inputs = tl.load(input_ptrs, mask=mask) - - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * 
BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - inputs += b - - act_grad = META["ACTIVATION_GRAD"](inputs) - grad_out *= act_grad - - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, grad_offsets, grad_out) - else: - output = grad_out - - # write-back - y_ptrs = GRAD_IN + grad_offsets - tl.store(y_ptrs, output, mask=mask) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] + TRAINABLE_BIAS = meta["TRAINABLE_BIAS"] + + rows = tl.arange(0, BLOCK_M) + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) + + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id # FIXME index the seed properly + + # pointers starting point + grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :] + grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :] + input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :] + + # random binary masks, save registers + rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4 + rand_mask1, rand_mask2, rand_mask3, rand_mask4 = _get_4_bin_masks(seed, rand_offsets, p) + + # now go over the tiles + grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) + col_mask = cols[None, :] < N + p_scale = 1 / (1 - p) if p < 1. else 1. + + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=col_mask, other=0.) + + for i in range(4): + # cycle through the binary masks (workaround / no indexing) + if i == 0: + rand_mask = rand_mask1 + elif i == 1: + rand_mask = rand_mask2 + elif i == 2: + rand_mask = rand_mask3 + else: + rand_mask = rand_mask4 + + block_mask = (rows[:, None] < M) & col_mask + grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.) + + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION_GRAD"]: + inputs = tl.load(input_ptrs, mask=block_mask, other=0.) + + # optionally apply a fused bias + if meta["USE_BIAS"]: + inputs += bias + + act_grad = meta["ACTIVATION_GRAD"](inputs).to(grad_out.dtype) + grad_out *= act_grad + + # randomly prune (and scale) the resulting buffer, possibly a no-op + # note that even if we did not save the mask from the FW pass, it is generated + # from the same seeds, so the same drop mask is applied here + output = _random_prune_and_scale(grad_out, rand_mask, p, p_scale) + + # write-back + tl.store(grad_in_ptrs, output, mask=block_mask) + + # optionally accumulate the bias gradient + if TRAINABLE_BIAS: + grad_bias += tl.sum(output, axis=0) + + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + grad_out_ptrs += BLOCK_M * stride_grad + input_ptrs += BLOCK_M * stride_inputs + grad_in_ptrs += BLOCK_M * stride_grad + + if TRAINABLE_BIAS: + grad_bias_ptr = GRAD_BIAS + row_id * N + cols + tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
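
Note on the RNG trick introduced in `_get_4_bin_masks` above: the kernel keeps the Philox output as raw int32 samples and maps the dropout probability onto the signed 32-bit range instead of normalizing the random numbers to floats. Below is a minimal PyTorch-only sketch of that thresholding and of the prune-and-rescale step, for illustration only: the helper name `int32_threshold_dropout` is hypothetical, `torch.randint` stands in for Triton's `tl.randint4x`, and the block/tile structure of the real kernel is not modeled.

import torch

def int32_threshold_dropout(x: torch.Tensor, p: float) -> torch.Tensor:
    # Map p from [0, 1] onto the signed 32-bit range, as in _get_4_bin_masks:
    # start_point + full_range * p
    threshold = int(-2147483648.0 + 4294967295.0 * p)
    # Raw int32-range uniforms stand in for what tl.randint4x returns per block
    rand = torch.randint(-2**31, 2**31, x.shape, dtype=torch.int64)
    keep = rand > threshold  # roughly a (1 - p) fraction of entries survives
    scale = 1.0 / (1.0 - p) if p < 1.0 else 1.0
    # Prune and rescale in one go, mirroring _random_prune_and_scale
    return torch.where(keep, x * scale, torch.zeros_like(x))

x = torch.randn(1024, 1024)
y = int32_threshold_dropout(x, p=0.3)
observed = 1.0 - y.count_nonzero().item() / y.numel()
print(f"observed drop rate ~ {observed:.3f} (target 0.3)")

Running this sketch shows the observed drop rate landing close to p, which is the property the updated `test_dropout` above checks against the Triton kernel (|observed - p| < 0.01).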