diff --git a/CHANGELOG.md b/CHANGELOG.md index 66c2cf20c..e82850914 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## TBD +### Fixed +- Much faster fused dropout [#164] + ## [0.0.7] - 2021-11-30 ### Fixed - Dropout setting not properly passed in many attentions [#123] diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png index 088ed1688..c4ebfb507 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png index f818ff455..fb483ce0f 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png index 82f6c035c..2b46f2b0a 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png index 022e73330..4995a7c0d 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png index a23fe0baf..044564e4d 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png index 8f466eb95..7a8bcab0c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png index 16f9e9097..a27b8809d 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png index dd90d79c8..05f51ac36 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png index ae7c4d533..b319f0827 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png index b84c86019..d18ec43b4 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png differ diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 108145e57..e2e5e0504 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -44,6 +44,15 @@ def test_dropout_cpu(): x = torch.normal(0, 1, size=(16, 16), device="cpu") _ = triton_dropout(x) + # Check eval means no dropout + triton_dropout.eval() + y = triton_dropout(x) + assert y.count_nonzero() == y.numel() + + triton_dropout.train() + y = triton_dropout(x) + assert y.count_nonzero() != y.numel() + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 4c8571de9..64c12f6b8 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -21,9 +21,11 @@ from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw from xformers.triton.sum_strided import sum_2d_dim_0 +# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64) +# for the random numbers to be good enough GROUP_M = 16 BLOCK_M = GROUP_M // 4 -BLOCK_N = 128 +BLOCK_N = 64 # Helper to handle the SPMD launch grid and error cases @@ -61,7 +63,8 @@ def grid(meta): USE_BIAS=bias is not None, ACTIVATION=activation, BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N + BLOCK_N=BLOCK_N, + num_warps=2 ) # fmt: on @@ -135,7 +138,7 @@ def grid(meta): TRAINABLE_BIAS=ctx.trainable_bias, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - num_warps=8 + num_warps=2 ) # fmt: on @@ -200,6 +203,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.bias is not None: # type: ignore self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore + # Train/inference + p = self.p if self.training else 0.0 + # This kernel is slower than pytorch for small buffers, bypassing it in that case perf_check = x.shape[-1] > 512 @@ -207,10 +213,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not x.is_cuda or not perf_check: x = x + self.bias if self.bias is not None else x x = self.pytorch_activation(x) - return torch.nn.functional.dropout(x, self.p) + return torch.nn.functional.dropout(x, p) # The normal, Triton-backed path - p = self.p if self.training else 0.0 return _dropout.apply( x, p, self.bias, self.activation, self.activation_grad, True )