flipping the seeds so that the dropout pattern drops down from the top

using fewer seeds

tiling + vertical seeds

computing the FW and BW per tile over M

better scheduling defaults, improves across the board

good enough perf

catching the slow case and diverting to PyTorch in that case
blefaudeux committed Dec 23, 2021
1 parent 5148844 commit b0d5f91
Showing 33 changed files with 297 additions and 156 deletions.
19 changes: 13 additions & 6 deletions BENCHMARKS.md
@@ -89,18 +89,25 @@ Note that in the Triton case the slowdowns at extreme sizes are because of register spilling

![Fused layer norm throughput in fp32 - training](docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png))

### Fused dropout + bias
### Fused dropout + bias + activation

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a nVidia V100, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop NVIDIA 3080, Triton 1.1 and PyTorch 1.10.

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png)
![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png)

![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png))
![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png)

![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png))
![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png)

![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png))
![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png)

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png)

![Fused dropout+ bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png)

## LRA

1 change: 1 addition & 0 deletions README.md
@@ -175,6 +175,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
5. Hackable
1. Not using monolithic CUDA kernels, composable building blocks
2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
3. Native support for SquaredReLU (on top of ReLU, LeakyReLU, GeLU, ...), extensible activations
### FAQ ?
24 files could not be displayed.
4 changes: 2 additions & 2 deletions examples/microGPT.py
@@ -311,9 +311,9 @@ def top_k_logits(logits, k):
gpus=1,
max_epochs=EPOCHS,
precision=16,
gradient_clip_val=1,
gradient_clip_val=1, # Use to catch divergent gradients, if experimenting
log_every_n_steps=1,
detect_anomaly=True,
# detect_anomaly=True, # Use to catch NaNs, if experimenting
accumulate_grad_batches=REF_BATCH // BATCH,
)

12 changes: 9 additions & 3 deletions tests/test_triton_dropout.py
@@ -53,7 +53,8 @@ def test_dropout_cpu():
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
def test_dropout(shape, amp, bias, p):
"""
Check some basic dropout properties
"""
@@ -97,6 +98,11 @@ def test_dropout(shape, amp, bias):
== y.shape[1]
)

# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.1


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -107,7 +113,7 @@ def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("p", [0, 0.001, 0.5])
@pytest.mark.parametrize("p", [0, 0.01, 0.5])
def test_dropout_parity(shape, amp, bias, activation, p):
"""
Check some basic dropout properties
@@ -158,4 +164,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad.norm()}\n{b_.grad.norm()}"
), f"{b.grad.norm()} - {b_.grad.norm()}"
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_triton_dropout.py
@@ -62,12 +62,14 @@ def torch_step(x):
y = torch_act(y)

if backward:
y.grad = None
torch.norm(y).backward()
return y

def triton_step(x):
y = triton_dropout(x)
if backward:
y.grad = None
torch.norm(y).backward()
return y

@@ -105,7 +107,7 @@ def triton_step(x):
)


for activation in [Activation.GeLU, None]:
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
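
The loop above sweeps activation (GeLU, none, squared ReLU) × backward × bias. As a point of reference, a single standalone call to the fused path it measures could look like the sketch below; the import paths and keyword names (`p`, `bias`, `activation`) are assumptions drawn from this diff, not a documented API.

```python
# Hedged sketch, not part of the commit: one call to the fused dropout + bias + activation
# path exercised by benchmark_triton_dropout.py. Assumes a CUDA device with Triton installed,
# and that the imports below match the package layout.
import torch

from xformers.components import Activation              # assumed location of the Activation enum
from xformers.triton import dropout as triton_dropout   # assumed re-export of xformers.triton.dropout.dropout

x = torch.randn(8, 1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16, requires_grad=True)

# Bias add, GeLU and dropout fused in a single Triton kernel
y = triton_dropout(x, p=0.1, bias=bias, activation=Activation.GeLU)

torch.norm(y).backward()  # same backward step as the benchmark's triton_step
print(x.grad.shape, bias.grad.shape)
```
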
10 changes: 5 additions & 5 deletions xformers/benchmarks/utils.py
@@ -28,12 +28,12 @@
def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<40}".format(units)
print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

offset = len(header)
print(
"|{}|".format("-" * offset)
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)

@@ -44,7 +44,7 @@ def pretty_print(results, title, units):

for k, w in workloads.items():
print(
"|{0:<{offset}}|".format(k, offset=offset)
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)

@@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
plt.xticks(rotation=45)

plt.savefig(filename, bbox_inches="tight")
plt.clf()
plt.close(f)


if _triton_is_available:
122 changes: 87 additions & 35 deletions xformers/triton/dropout.py
@@ -21,47 +21,56 @@
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
from xformers.triton.sum_strided import sum_2d_dim_0

GROUP_M = 16
BLOCK_M = GROUP_M // 4
BLOCK_N = 128


# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, activation_grad):
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
# Soft-flatten a hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
_, N = x_.shape

assert bias is None or bias.dtype == x.dtype, bias
M, N = x_.shape

# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32)
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)

# SPMD launch grid
def grid(meta):
return (
x_.shape[0],
triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

N_BLOCK_N = triton.cdiv(N, BLOCK_N)

# Generate one seed per block of columns ("vertical" seeds: the random pattern drops down over M)
# Seeds are drawn in [0, 2**16), which fits comfortably in a positive int32
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32)

# fmt: off
k_dropout_fw[grid](
y, x_, bias if bias is not None else x_,
y, x_,
bias if bias is not None else x_,
seeds,
y.stride(0),
N,
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N
)
# fmt: on

if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None

ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.p = p

@@ -76,40 +85,68 @@ def backward(ctx, grad_out):
grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
grad_in = torch.empty_like(grad_out_)

_, N = grad_out_.shape
M, N = grad_out_.shape

# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation_grad is None

if inputs is None:
inputs = grad_out_
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, grad_out.shape[-1])
inputs = inputs.reshape(-1, N)

# We split the problem into tiles:
# - over M there will be a follow-up reduction
# - over M, we go by 4 tiles at a time (a consequence of the random number generation)
# - over N we compromise between using as much memory parallelism as possible
# (fill the warps; there are 32 threads per warp, and 4 warps by default) and not going
# too big, because of register spilling
N_BLOCKS_M = triton.cdiv(M, GROUP_M)

if ctx.trainable_bias:
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)

else:
grad_bias = grad_in # will not be used

# SPMD launch grid
def grid(meta):
return (
grad_out_.shape[0],
triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

# fmt: off
k_dropout_bw[grid](
grad_in, grad_out_, inputs, bias if bias is not None else inputs,
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds,
grad_out_.stride(0), inputs.stride(0),
N,
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad)
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=8
)
# fmt: on

if ctx.trainable_bias:
grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in)
else:
grad_bias = None

return grad_in.reshape_as(grad_out), None, grad_bias, None, None
return (
grad_in.reshape_as(grad_out),
None,
sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None,
None,
None,
None,
)


def dropout(
@@ -129,7 +166,14 @@ def dropout(

act_kernel = get_triton_activation_kernel(activation)
act_grad_kernel = get_triton_activation_bwd_kernel(activation)
return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel)
return _dropout.apply(
x,
p,
bias,
act_kernel,
act_grad_kernel,
bias is not None and bias.requires_grad,
)


class FusedDropoutBias(torch.nn.Module):
@@ -142,23 +186,31 @@ def __init__(
super().__init__()
self.p = p
self.activation_type = activation
self.register_buffer(
"bias", torch.zeros(bias_shape) if bias_shape is not None else None
self.bias = (
torch.zeros(bias_shape, requires_grad=True)
if bias_shape is not None
else None
)
self.activation = get_triton_activation_kernel(activation)
self.pytorch_activation = build_activation(self.activation_type)
self.activation_grad = get_triton_activation_bwd_kernel(activation)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Convenience, catch a possible type or device mismatch
if self.bias is not None: # type: ignore
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

# This kernel is slower than pytorch for small buffers, bypassing it in that case
perf_check = x.shape[-1] > 512

# Catch a non-cuda setup, fallback to pytorch
if not x.is_cuda:
activation = build_activation(self.activation_type)
if not x.is_cuda or not perf_check:
x = x + self.bias if self.bias is not None else x
x = activation(x)
x = self.pytorch_activation(x)
return torch.nn.functional.dropout(x, self.p)

# The normal, Triton-backed path
p = self.p if self.training else 0.0
return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad)
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)
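
Two behaviours from this diff are worth noting when using the reworked module: the bias is now a plain trainable tensor (so its gradient flows through the fused kernel), and inputs that are not on CUDA or whose last dimension is 512 or smaller are diverted to the regular PyTorch path. A hedged usage sketch, with assumed import paths:

```python
# Hedged usage sketch (not part of this commit); assumes FusedDropoutBias and Activation
# are importable from the paths below, and that a CUDA device with Triton is available.
import torch

from xformers.components import Activation          # assumed location of the Activation enum
from xformers.triton.dropout import FusedDropoutBias

fused = FusedDropoutBias(p=0.2, bias_shape=1024, activation=Activation.SquaredReLU)

# Large CUDA input: bias + squared ReLU + dropout run in the fused Triton kernel
x = torch.randn(16, 1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
y = fused(x)
torch.norm(y).backward()

# A CPU input (or a last dimension <= 512) is quietly diverted to the plain
# bias + activation + torch.nn.functional.dropout path
x_cpu = torch.randn(16, 128, 1024, requires_grad=True)
y_cpu = fused(x_cpu)
```
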
9 changes: 3 additions & 6 deletions xformers/triton/k_activations.py
@@ -64,8 +64,7 @@ def relu(x):
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
zero = 0.0
zero = zero.to(x.dtype)
return tl.where(x >= 0, x, zero)
return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
@@ -74,10 +73,8 @@ def relu_grad(x):
# in that it does not require the input to retrospectively compute its gradient
# here the input is the downstream gradient, and we return the upstream gradient directly
zero = 0.0
zero = zero.to(x.dtype)
one = 1.0
one = one.to(x.dtype)
return tl.where(x >= 0, one, zero)
return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
@@ -88,7 +85,7 @@ def squared_relu(x):
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_ = relu(x)
return x_ * x_
return (x_ * x_).to(x.dtype)


@triton.jit
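
Squared ReLU above follows the Primer paper linked in the docstring: `relu(x)` squared. A plain PyTorch reference, e.g. for parity checks against the Triton kernel, is a hedged sketch rather than repository code:

```python
# PyTorch reference for squared ReLU (Primer: https://arxiv.org/abs/2109.08668),
# useful as a parity check against the Triton kernel above. Not repository code.
import torch

def squared_relu_ref(x: torch.Tensor) -> torch.Tensor:
    x_ = torch.nn.functional.relu(x)
    return x_ * x_  # derivative w.r.t. x is 2 * relu(x), i.e. zero for x <= 0

x = torch.randn(4, 8, requires_grad=True)
torch.norm(squared_relu_ref(x)).backward()
print(x.grad)
```
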