WIP better scheduling, improving perfs all around
blefaudeux committed Dec 26, 2021
1 parent 9ee09ee commit b36e8ca
Showing 5 changed files with 29 additions and 10 deletions.
4 changes: 2 additions & 2 deletions examples/microGPT.py
@@ -73,7 +73,7 @@ def __init__(
"feedforward_config": {
"name": "FusedMLP",  # Use MLP if Triton is not available
"dropout": self.hparams.mlp_pdrop,
- "activation": "gelu",
+ "activation": "squared_relu",
"hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
},
}
@@ -273,7 +273,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
- BATCH = 256
+ BATCH = 128

WORKERS = 4
EPOCHS = 1
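For context, "squared_relu" is the squared-ReLU activation (popularized by Primer): ReLU followed by squaring. A minimal PyTorch sketch of what it computes, not the fused Triton path behind FusedMLP:

import torch


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU: zero for x <= 0, x * x otherwise; avoids GELU's erf/tanh evaluation.
    return torch.square(torch.relu(x))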
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_triton_dropout.py
@@ -87,7 +87,9 @@ def triton_step(x):
),
),
]:
- time = triton.testing.do_bench(lambda: testcase.function(a))[0]
+ time = triton.testing.do_bench(
+ lambda: testcase.function(a), grad_to_none=[a, b]
+ )[0]
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
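For reference, triton.testing.do_bench times a callable over repeated runs, and grad_to_none takes tensors whose .grad is reset to None between runs, so backward passes are not accumulating into an existing gradient buffer. A minimal sketch, assuming a Triton version where do_bench returns a tuple of timings (hence the [0] indexing above); the tensor name x is made up:

import torch
import triton

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)


def step():
    torch.nn.functional.dropout(x, p=0.1).sum().backward()


# Clearing x.grad between timed runs keeps every iteration measuring the same work.
time_ms = triton.testing.do_bench(step, grad_to_none=[x])[0]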
8 changes: 4 additions & 4 deletions xformers/triton/dropout.py
@@ -23,9 +23,9 @@

# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64)
# for the random numbers to be good enough
- GROUP_M = 16
+ GROUP_M = 8
BLOCK_M = GROUP_M // 4
- BLOCK_N = 64
+ BLOCK_N = 128


# Helper to handle the SPMD launch grid and error cases
@@ -60,11 +60,11 @@ def grid(meta):
y.stride(0),
M, N,
p,
+ x.dtype == torch.float16,
USE_BIAS=bias is not None,
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
- num_warps=2
)
# fmt: on

@@ -133,12 +133,12 @@ def grid(meta):
grad_out_.stride(0), inputs.stride(0),
M, N,
ctx.p,
+ grad_in.dtype == torch.float16,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
- num_warps=2
)
# fmt: on
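Note that the per-program tile keeps the same element count after this change, only its shape changes: BLOCK_M = GROUP_M // 4 drops from 16 // 4 = 4 to 8 // 4 = 2 while BLOCK_N doubles from 64 to 128, so the SIZE_RAND_BLOCK derived in k_dropout.py below stays at 256. A quick check:

old_tile = (16 // 4) * 64    # previous GROUP_M=16, BLOCK_N=64
new_tile = (8 // 4) * 128    # new GROUP_M=8, BLOCK_N=128
assert old_tile == new_tile == 256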
18 changes: 18 additions & 0 deletions xformers/triton/k_dropout.py
@@ -10,15 +10,28 @@
import triton
import triton.language as tl

+ _configs = [
+ triton.Config({}, num_warps=1),
+ triton.Config({}, num_warps=2),
+ triton.Config({}, num_warps=4),
+ triton.Config({}, num_warps=8),
+ triton.Config({}, num_warps=16),
+ ]
+

# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]})
+ @triton.autotune(
+ configs=_configs,
+ key=["M", "N", "is_fp16"],
+ )
@triton.jit
def k_dropout_fw(
Y, X, BIAS, SEEDS,
stride,
M, N,
p,
+ is_fp16,  # autotune
**meta,
):
"""
@@ -108,13 +121,18 @@ def k_dropout_fw(

# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]})
+ @triton.autotune(
+ configs=_configs,
+ key=["M", "N", "is_fp16"],
+ )
@triton.jit
def k_dropout_bw(
GRAD_IN, GRAD_BIAS, GRAD_OUT,
INPUTS, BIAS, SEEDS,
stride_grad, stride_inputs,
M, N,
p,
+ is_fp16,  # autotune
**meta,
):
"""
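The pattern in this change: rather than hard-coding num_warps=2 at the call sites (removed in xformers/triton/dropout.py above), each kernel declares a list of triton.Config candidates and lets @triton.autotune pick one per key value, and passing x.dtype == torch.float16 through as is_fp16 makes the dtype part of that cache key. A minimal sketch of the same pattern on a made-up elementwise kernel, assuming the Triton 1.x-style **meta launch API used in this file (k_scale and its arguments are illustrative, not part of xFormers):

import torch
import triton
import triton.language as tl

_configs = [triton.Config({}, num_warps=w) for w in (1, 2, 4, 8)]


@triton.autotune(configs=_configs, key=["N", "is_fp16"])
@triton.jit
def k_scale(Y, X, N, factor, is_fp16, **meta):
    # BLOCK is a compile-time constant supplied at launch, like BLOCK_M / BLOCK_N above.
    BLOCK = meta["BLOCK"]
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask, other=0.0)
    tl.store(Y + offsets, x * factor, mask=mask)


def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    # is_fp16 only feeds the autotune key: each (N, dtype) pair caches its own num_warps choice.
    k_scale[grid](y, x, n, factor, x.dtype == torch.float16, BLOCK=1024)
    return y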
5 changes: 2 additions & 3 deletions xformers/triton/sum_strided.py
@@ -36,10 +36,10 @@ def sum_2d_dim_0(x: torch.Tensor):
)

BLOCK_M = min(triton.next_power_of_2(M), 2048)
- BLOCK_N = 32
+ BLOCK_N = 48
if BLOCK_M > 256:
BLOCK_N = 16
- if BLOCK_M > 1024:
+ if BLOCK_M >= 1024:
BLOCK_N = 8

def grid(meta):
@@ -53,7 +53,6 @@ def grid(meta):
x.dtype == torch.float16,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
- num_stages=4,
)
# fmt: on
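The block-size heuristic above can be checked in isolation; the change from > to >= matters exactly when the padded size lands on 1024, which now also selects the narrowest BLOCK_N. A standalone sketch of the same selection logic (not the library code):

import triton


def pick_blocks(m: int):
    # Mirrors the heuristic in sum_2d_dim_0: larger BLOCK_M means smaller BLOCK_N.
    block_m = min(triton.next_power_of_2(m), 2048)
    block_n = 48
    if block_m > 256:
        block_n = 16
    if block_m >= 1024:
        block_n = 8
    return block_m, block_n


assert pick_blocks(100) == (128, 48)
assert pick_blocks(1000) == (1024, 8)   # with the previous ">" test this was (1024, 16)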
