
[feat] Dropout partial bw fusion (second take) #164

Merged: 7 commits, merged on Jan 3, 2022
23 changes: 15 additions & 8 deletions BENCHMARKS.md
@@ -58,7 +58,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)

![Fused linear layers throughput in fp16 - training](docs/plots/fused_linea/FusedLinear_fp16_FW_BW_gelu.png)
![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png)

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png)

@@ -74,7 +74,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_none.png)

![Fused linear layers throughput in fp16 - training](docs/plots/fused_line/FusedLinear_fp16_FW_BW_none.png)
![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png)

### Fused layer norm

@@ -89,18 +89,25 @@ Note that in the Triton case the slowdowns at extreme sizes are because of regis

![Fused layer norm throughput in fp32 - training](docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png)

### Fused dropout + bias
### Fused dropout + bias + activation

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a nVidia V100, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

![Fused dropout + bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16.png)
![Fused dropout + bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_gelu.png)

![Fused dropout + bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16.png)
![Fused dropout + bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_gelu.png)

![Fused dropout + bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float32.png)
![Fused dropout + bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png)

![Fused dropout + bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32.png)
![Fused dropout + bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_gelu.png)

![Fused dropout + bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act:_squared_relu.png)

![Fused dropout + bias throughput in fp16 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float16_Act:_squared_relu.png)

![Fused dropout + bias throughput in fp32 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png)

![Fused dropout + bias throughput in fp32 - training](docs/plots/fused_dropout/Dropout_Bias_True_FW+BW_torch.float32_Act:_squared_relu.png)
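For reference, the computation benchmarked here is the bias + activation + dropout epilogue of a feed-forward block. A minimal unfused PyTorch sketch of the same math (the function below is illustrative, not the xformers API; the Triton kernel fuses these steps into a single memory pass):

```python
import torch
import torch.nn.functional as F


def dropout_bias_act_reference(x, bias=None, p=0.1, activation=F.gelu, training=True):
    # Unfused reference: three separate ops, each reading and writing the
    # full tensor. The fused Triton kernel does the same work in one pass.
    if bias is not None:
        x = x + bias
    if activation is not None:
        x = activation(x)
    return F.dropout(x, p=p, training=training)


x = torch.randn(4, 1024, 4096)
y = dropout_bias_act_reference(x, bias=torch.zeros(4096), p=0.1)
```

Since each unfused step is memory-bound, fusion mostly saves memory round-trips, which is why throughput is reported in GB/s rather than FLOP/s.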

## LRA

4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
### Fixed
- Much faster fused dropout [#164]

## [0.0.7] - 2021-11-30
### Fixed
- Dropout setting not properly passed in many attentions [#123]
1 change: 1 addition & 0 deletions README.md
@@ -175,6 +175,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
5. Hackable
1. Not using monolithic CUDA kernels, composable building blocks
2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
3. Native support for SquaredReLU (on top of ReLU, LeakyReLU, GeLU, ..), extensible activations
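For context, SquaredReLU (So et al., *Primer*, 2021) is simply the ReLU output squared, elementwise; a one-line PyTorch sketch:

```python
import torch


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # SquaredReLU(x) = ReLU(x) ** 2, applied elementwise
    return torch.square(torch.relu(x))
```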

### FAQ ?

[Several binary files changed (benchmark plot images); previews not rendered in this view.]
4 changes: 2 additions & 2 deletions examples/microGPT.py
@@ -311,9 +311,9 @@ def top_k_logits(logits, k):
gpus=1,
max_epochs=EPOCHS,
precision=16,
gradient_clip_val=1,
gradient_clip_val=1, # Use to catch divergent gradients, if experimenting
log_every_n_steps=1,
detect_anomaly=True,
# detect_anomaly=True, # Use to catch NaNs, if experimenting
accumulate_grad_batches=REF_BATCH // BATCH,
)

21 changes: 18 additions & 3 deletions tests/test_triton_dropout.py
@@ -44,6 +44,15 @@ def test_dropout_cpu():
x = torch.normal(0, 1, size=(16, 16), device="cpu")
_ = triton_dropout(x)

# Check eval means no dropout
triton_dropout.eval()
y = triton_dropout(x)
assert y.count_nonzero() == y.numel()

triton_dropout.train()
y = triton_dropout(x)
assert y.count_nonzero() != y.numel()


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -53,7 +62,8 @@ def test_dropout_cpu():
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("p", [0, 0.1, 0.5])
def test_dropout(shape, amp, bias, p):
"""
Check some basic dropout properties
"""
@@ -97,6 +107,11 @@ def test_dropout(shape, amp, bias):
== y.shape[1]
)

# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.01


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -107,7 +122,7 @@ def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("p", [0, 0.001, 0.5])
@pytest.mark.parametrize("p", [0, 0.01, 0.5])
def test_dropout_parity(shape, amp, bias, activation, p):
"""
Check some basic dropout properties
@@ -158,4 +173,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad.norm()}\n{b_.grad.norm()}"
), f"{b.grad.norm()} - {b_.grad.norm()}"
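The new assertions above encode the standard dropout contract: identity in eval mode, roughly a fraction `p` of elements zeroed in train mode. For illustration only, the same properties can be checked against PyTorch's stock `torch.nn.Dropout` (not the Triton module):

```python
import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(64, 64)

# Eval mode: dropout must be the identity, nothing is zeroed out.
drop.eval()
assert drop(x).count_nonzero() == x.numel()

# Train mode: roughly p of the elements are zeroed, and the survivors are
# rescaled by 1 / (1 - p) so the expected value is unchanged.
drop.train()
y = drop(x)
observed_p = 1.0 - y.count_nonzero().item() / y.numel()
print(f"observed drop rate ~ {observed_p:.2f}")  # close to 0.50 on average
```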
8 changes: 6 additions & 2 deletions xformers/benchmarks/benchmark_triton_dropout.py
@@ -62,12 +62,14 @@ def torch_step(x):
y = torch_act(y)

if backward:
y.grad = None
torch.norm(y).backward()
return y

def triton_step(x):
y = triton_dropout(x)
if backward:
y.grad = None
torch.norm(y).backward()
return y

@@ -85,7 +87,9 @@
),
),
]:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
time = triton.testing.do_bench(
lambda: testcase.function(a), grad_to_none=[a, b]
)[0]
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
@@ -105,7 +109,7 @@
)


for activation in [Activation.GeLU, None]:
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
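The `grad_to_none` argument added above matters for the backward timings: if `.grad` buffers survive between timed runs, `backward()` accumulates into them instead of writing fresh gradients, which changes the kernels being measured. A minimal sketch of the pattern, assuming a CUDA device (Triton 1.x `do_bench` returns a tuple of timings in milliseconds, hence the `[0]` indexing used above):

```python
import torch
import triton

# Requires a CUDA device.
x = torch.randn(2048, 2048, device="cuda", requires_grad=True)


def step():
    y = torch.nn.functional.gelu(x)
    torch.norm(y).backward()


# grad_to_none resets x.grad between runs, so every timed backward pass is a
# fresh gradient write rather than an in-place accumulation.
mean_ms = triton.testing.do_bench(step, grad_to_none=[x])[0]
print(f"{mean_ms:.3f} ms per iteration")
```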
10 changes: 5 additions & 5 deletions xformers/benchmarks/utils.py
@@ -28,12 +28,12 @@
def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<40}".format(units)
print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

offset = len(header)
print(
"|{}|".format("-" * offset)
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)

@@ -44,7 +44,7 @@

for k, w in workloads.items():
print(
"|{0:<{offset}}|".format(k, offset=offset)
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)

@@ -85,7 +85,7 @@
plt.xticks(rotation=45)

plt.savefig(filename, bbox_inches="tight")
plt.clf()
plt.close(f)


if _triton_is_available: