diff --git a/examples/microGPT.py b/examples/microGPT.py
index fe664d1c1c..685d352f8a 100644
--- a/examples/microGPT.py
+++ b/examples/microGPT.py
@@ -73,7 +73,7 @@ def __init__(
             "feedforward_config": {
                 "name": "FusedMLP",  # Use MLP if Triton is not available
                 "dropout": self.hparams.mlp_pdrop,
-                "activation": "gelu",
+                "activation": "squared_relu",
                 "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
             },
         }
@@ -273,7 +273,7 @@ def top_k_logits(logits, k):
     # Adjust batch depending on the available memory on your machine.
     # You can also use reversible layers to save memory
     REF_BATCH = 512
-    BATCH = 256
+    BATCH = 128
 
     WORKERS = 4
     EPOCHS = 1
diff --git a/xformers/benchmarks/benchmark_triton_dropout.py b/xformers/benchmarks/benchmark_triton_dropout.py
index d13e3fe81c..e7e3bf738b 100644
--- a/xformers/benchmarks/benchmark_triton_dropout.py
+++ b/xformers/benchmarks/benchmark_triton_dropout.py
@@ -87,7 +87,9 @@ def triton_step(x):
             ),
         ),
     ]:
-        time = triton.testing.do_bench(lambda: testcase.function(a))[0]
+        time = triton.testing.do_bench(
+            lambda: testcase.function(a), grad_to_none=[a, b]
+        )[0]
         key = f"B={B}, M={M}, K={K}"
         if key not in results:
             results[key] = {}
diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py
index 64c12f6b84..9908fa66e3 100644
--- a/xformers/triton/dropout.py
+++ b/xformers/triton/dropout.py
@@ -23,9 +23,9 @@
 
 # NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64)
 # for the random numbers to be good enough
-GROUP_M = 16
+GROUP_M = 8
 BLOCK_M = GROUP_M // 4
-BLOCK_N = 64
+BLOCK_N = 128
 
 
 # Helper to handle the SPMD launch grid and error cases
@@ -60,11 +60,11 @@ def grid(meta):
             y.stride(0),
             M, N,
             p,
+            x.dtype == torch.float16,
             USE_BIAS=bias is not None,
             ACTIVATION=activation,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
-            num_warps=2
         )
         # fmt: on
 
@@ -133,12 +133,12 @@ def grid(meta):
             grad_out_.stride(0), inputs.stride(0),
             M, N,
             ctx.p,
+            grad_in.dtype == torch.float16,
             USE_BIAS=bias is not None,
             ACTIVATION_GRAD=ctx.activation_grad,
             TRAINABLE_BIAS=ctx.trainable_bias,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
-            num_warps=2
         )
         # fmt: on
 
diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py
index c43c61d666..f37511491e 100644
--- a/xformers/triton/k_dropout.py
+++ b/xformers/triton/k_dropout.py
@@ -10,15 +10,28 @@
 import triton
 import triton.language as tl
 
+_configs = [
+    triton.Config({}, num_warps=1),
+    triton.Config({}, num_warps=2),
+    triton.Config({}, num_warps=4),
+    triton.Config({}, num_warps=8),
+    triton.Config({}, num_warps=16),
+]
+
 
 # fmt: off
 @triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]})
+@triton.autotune(
+    configs=_configs,
+    key=["M", "N", "is_fp16"],
+)
 @triton.jit
 def k_dropout_fw(
     Y, X, BIAS, SEEDS,
     stride,
     M, N,
     p,
+    is_fp16,  # autotune
     **meta,
 ):
     """
@@ -108,6 +121,10 @@ def k_dropout_fw(
 
 # fmt: off
 @triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]})
+@triton.autotune(
+    configs=_configs,
+    key=["M", "N", "is_fp16"],
+)
 @triton.jit
 def k_dropout_bw(
     GRAD_IN, GRAD_BIAS, GRAD_OUT,
@@ -115,6 +132,7 @@ def k_dropout_bw(
     stride_grad, stride_inputs,
     M, N,
     p,
+    is_fp16,  # autotune
     **meta,
 ):
     """
diff --git a/xformers/triton/sum_strided.py b/xformers/triton/sum_strided.py
index ed23384bb4..f3f29c3931 100644
--- a/xformers/triton/sum_strided.py
+++ b/xformers/triton/sum_strided.py
@@ -36,10 +36,10 @@ def sum_2d_dim_0(x: torch.Tensor):
         )
 
     BLOCK_M = min(triton.next_power_of_2(M), 2048)
-    BLOCK_N = 32
+    BLOCK_N = 48
     if BLOCK_M > 256:
         BLOCK_N = 16
-    if BLOCK_M > 1024:
+    if BLOCK_M >= 1024:
         BLOCK_N = 8
 
     def grid(meta):
@@ -53,7 +53,6 @@ def grid(meta):
         x.dtype == torch.float16,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
-        num_stages=4,
     )
     # fmt: on
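
Note (illustrative, not part of the patch above): the diff swaps the hard-coded num_warps=2 launch argument for triton.autotune, which benchmarks each triton.Config in _configs the first time the kernel is launched for a new value of the key arguments and caches the fastest one; is_fp16 is added to the kernel signature so the dtype can take part in that key. Below is a minimal, self-contained sketch of the same pattern under stated assumptions: the kernel name k_copy, its trivial copy body, and the launch sizes are made up for the example and are not xformers code.

    import torch
    import triton
    import triton.language as tl

    # Search space: the same kernel compiled with different warp counts.
    _configs = [triton.Config({}, num_warps=w) for w in (1, 2, 4, 8, 16)]


    @triton.autotune(configs=_configs, key=["N", "is_fp16"])  # re-tune when N or dtype changes
    @triton.jit
    def k_copy(Y, X, N, is_fp16, BLOCK_N: tl.constexpr):
        # Trivial element-wise copy; only num_warps is being tuned here.
        pid = tl.program_id(0)
        offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = offs < N
        tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)


    x = torch.randn(4096, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_N"]),)
    # The first call times all five configs for this (N, is_fp16) pair, then reuses the winner.
    k_copy[grid](y, x, x.numel(), x.dtype == torch.float16, BLOCK_N=1024)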