
Commit 8b9befa
larger tiles and better speed, full loss curve checked
blefaudeux committed Dec 27, 2021
1 parent b36e8ca commit 8b9befa
Showing 27 changed files with 8 additions and 10 deletions.
24 files could not be displayed (binary or invalid).
examples/microGPT.py (4 changes: 2 additions & 2 deletions)
@@ -73,7 +73,7 @@ def __init__(
             "feedforward_config": {
                 "name": "FusedMLP",  # Use MLP if Triton is not available
                 "dropout": self.hparams.mlp_pdrop,
-                "activation": "squared_relu",
+                "activation": "gelu",
                 "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
             },
         }
@@ -273,7 +273,7 @@ def top_k_logits(logits, k):
     # Adjust batch depending on the available memory on your machine.
     # You can also use reversible layers to save memory
     REF_BATCH = 512
-    BATCH = 128
+    BATCH = 256

     WORKERS = 4
     EPOCHS = 1
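
Note: the first hunk above swaps microGPT's feedforward activation from squared ReLU to GELU. For reference, the two differ only elementwise; a minimal standalone PyTorch sketch (illustrative only, not the fused xformers kernels):

import torch
import torch.nn.functional as F

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Previous activation: ReLU squared, computed elementwise
    return torch.square(F.relu(x))

x = torch.randn(4)
print(squared_relu(x))  # before this commit
print(F.gelu(x))        # after this commit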
xformers/triton/dropout.py (9 changes: 3 additions & 6 deletions)
@@ -19,13 +19,10 @@
     get_triton_activation_kernel,
 )
 from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
-from xformers.triton.sum_strided import sum_2d_dim_0

-# NOTE: GROUP_M and BLOCK_N need to be kept low (<16x64)
-# for the random numbers to be good enough
-GROUP_M = 8
+GROUP_M = 32  # 32
 BLOCK_M = GROUP_M // 4
-BLOCK_N = 128
+BLOCK_N = 128  # 128


 # Helper to handle the SPMD launch grid and error cases
@@ -145,7 +142,7 @@ def grid(meta):
         return (
             grad_in.reshape_as(grad_out),
             None,
-            sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None,
+            torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
             None,
             None,
             None,
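
The custom sum_2d_dim_0 helper is replaced here by a plain torch.sum over dim 0. A quick sanity check of why that reduction is the right bias gradient, assuming the forward adds a bias broadcast over rows (a sketch, not the actual Triton path):

import torch

x = torch.randn(8, 16, requires_grad=True)
bias = torch.randn(16, requires_grad=True)

# y = x + bias broadcasts bias across rows, so the bias gradient
# accumulates the upstream gradient over the row (batch) dimension.
y = x + bias
grad_out = torch.rand_like(y)
y.backward(grad_out)

# autograd's bias gradient matches the explicit column-wise sum
assert torch.allclose(bias.grad, torch.sum(grad_out, dim=0))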
xformers/triton/sum_strided.py (5 changes: 3 additions & 2 deletions)
@@ -36,10 +36,10 @@ def sum_2d_dim_0(x: torch.Tensor):
     )

     BLOCK_M = min(triton.next_power_of_2(M), 2048)
-    BLOCK_N = 48
+    BLOCK_N = 32
     if BLOCK_M > 256:
         BLOCK_N = 16
-    if BLOCK_M >= 1024:
+    if BLOCK_M > 1024:
         BLOCK_N = 8

     def grid(meta):
@@ -53,6 +53,7 @@ def grid(meta):
         x.dtype == torch.float16,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
+        num_stages=4,
     )
     # fmt: on
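
The retuned tile heuristic shrinks BLOCK_N as BLOCK_M grows, trading column parallelism for larger row tiles, and num_stages=4 asks Triton to software-pipeline the kernel more deeply. A standalone sketch of the selection logic mirroring the diff above (pick_blocks is a hypothetical name, not part of xformers):

import triton

def pick_blocks(M: int):
    # Mirrors the post-commit heuristic in sum_2d_dim_0
    BLOCK_M = min(triton.next_power_of_2(M), 2048)
    BLOCK_N = 32
    if BLOCK_M > 256:
        BLOCK_N = 16
    if BLOCK_M > 1024:
        BLOCK_N = 8
    return BLOCK_M, BLOCK_N

print(pick_blocks(128))   # (128, 32)
print(pick_blocks(512))   # (512, 16)
print(pick_blocks(2048))  # (2048, 8)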
