diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png index e3b7d96237..f986ced1ca 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png index c4ebfb507c..5e451751ca 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png index 2d7b7a3cf9..c7c119fb8d 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png index 48edecc177..34ede9b2f0 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png index fb483ce0ff..7c841f03f4 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png index 023e5acdfc..5935eb314c 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW+BW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png index 76a8a946b3..93ef40d7ff 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png index 2b46f2b0a9..e02db4f41b 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png index 8f926ad8bb..ea424a1e14 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png index 06d97958b0..4f1baae3d1 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png index 4995a7c0d2..5ec1e20db9 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png index 4505405723..86e87e1d53 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_False_FW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png index 044564e4d2..e12bec88d0 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png index 7a8bcab0c0..1cbd246e14 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png index 6169e7a2e7..77207f15c1 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png index a27b8809d7..07c990ca43 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png index 05f51ac367..f719b544e8 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png index ed9fd311fd..cd7321e65f 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png index 960da502b5..629b5cda5b 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png index b319f08279..b95ac75168 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png index 604d8c91e9..cc52a70023 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png index a5ae8721c3..50a4e89906 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_None.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png index d18ec43b40..9d9eb97732 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_gelu.png differ diff --git a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png index e4a27d2eeb..ec3fd66519 100644 Binary files a/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png and b/docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32_Act:_squared_relu.png differ diff --git a/examples/microGPT.py b/examples/microGPT.py index 685d352f8a..fe664d1c1c 100644 --- a/examples/microGPT.py +++ b/examples/microGPT.py @@ -73,7 +73,7 @@ def __init__( "feedforward_config": { "name": "FusedMLP", # Use MLP if Triton is not available "dropout": self.hparams.mlp_pdrop, - "activation": "squared_relu", + "activation": "gelu", "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier, }, } @@ -273,7 +273,7 @@ def top_k_logits(logits, k): # Adjust batch depending on the available memory on your machine. # You can also use reversible layers to save memory REF_BATCH = 512 - BATCH = 128 + BATCH = 256 WORKERS = 4 EPOCHS = 1 diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 9908fa66e3..853f27fc45 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -19,13 +19,10 @@ get_triton_activation_kernel, ) from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw -from xformers.triton.sum_strided import sum_2d_dim_0 -# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64) -# for the random numbers to be good enough -GROUP_M = 8 +GROUP_M = 32 # 32 BLOCK_M = GROUP_M // 4 -BLOCK_N = 128 +BLOCK_N = 128 # 128 # Helper to handle the SPMD launch grid and error cases @@ -145,7 +142,7 @@ def grid(meta): return ( grad_in.reshape_as(grad_out), None, - sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None, + torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None, None, None, None, diff --git a/xformers/triton/sum_strided.py b/xformers/triton/sum_strided.py index f3f29c3931..ed23384bb4 100644 --- a/xformers/triton/sum_strided.py +++ b/xformers/triton/sum_strided.py @@ -36,10 +36,10 @@ def sum_2d_dim_0(x: torch.Tensor): ) BLOCK_M = min(triton.next_power_of_2(M), 2048) - BLOCK_N = 48 + BLOCK_N = 32 if BLOCK_M > 256: BLOCK_N = 16 - if BLOCK_M >= 1024: + if BLOCK_M > 1024: BLOCK_N = 8 def grid(meta): @@ -53,6 +53,7 @@ def grid(meta): x.dtype == torch.float16, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + num_stages=4, ) # fmt: on