precompute softmax D in non-cuda ring flash attn
lucidrains committed Apr 9, 2024
1 parent 0b5d5af commit 011ec89
Showing 3 changed files with 12 additions and 6 deletions.
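The commit hoists the softmax `D` term of the flash attention backward out of the per-chunk loop. As a worked view of why that is safe (notation follows the flash attention papers, not names from this repo):

    D_i   = rowsum(dO_i ∘ O_i)
    dS_ij = P_ij · (dP_ij − D_i)

`D` depends only on the forward output `o` and the incoming gradient `do`, never on the key / value chunks that rotate around the ring, so it stays constant across the ring reduce and can be computed once up front instead of once per row chunk per ring pass.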
1 change: 1 addition & 0 deletions README.md
@@ -99,6 +99,7 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
- [x] for cuda striped attention, for backwards hack, pad the extra token once and index out when passing into Tri's cuda kernel
- [x] find a machine with 8 GPUs and test with a quarter million tokens first

- [ ] see for the cuda version whether softmax_D can be computed once and cached over the ring reduce; go for a modified triton backwards if not
- [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
- [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- [ ] figure out how to pytest distributed pytorch
15 changes: 10 additions & 5 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -256,6 +256,13 @@ def backward(ctx, do):
receive_kv_and_dkv = None
receive_mask = None

# precompute the softmax D

D = (do * o).sum(dim = -1, keepdims = True)
D = rearrange(D, 'b n h 1 -> b h n 1')

# ring reduce key / values

for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
k_ring_rank = ring_rank % ring_size

@@ -274,13 +281,13 @@ def backward(ctx, do):

row_splits = zip(
q.split(bucket_size, dim = 1),
o.split(bucket_size, dim = 1),
do.split(bucket_size, dim = 1),
D.split(bucket_size, dim = -2),
lse.split(bucket_size, dim = -2),
dq.split(bucket_size, dim = 1)
)

for ind, (qc, oc, doc, lsec, dqc) in enumerate(row_splits):
for ind, (qc, doc, Dc, lsec, dqc) in enumerate(row_splits):
row_bucket_index = row_ring_rank * per_machine_buckets + ind

attn_weights = einsum('b i h d, b j h d -> b h i j', qc, kc) * scale
@@ -311,9 +318,7 @@ def backward(ctx, do):
dv_chunk = einsum('b h i j, b i h d -> b j h d', p, doc)
dp = einsum('b i h d, b j h d -> b h i j', doc, vc)

D = (doc * oc).sum(dim = -1, keepdims = True)
D = rearrange(D, 'b n h 1 -> b h n 1')
ds = p * scale * (dp - D)
ds = p * scale * (dp - Dc)

dq_chunk = einsum('b h i j, b j h d -> b i h d', ds, kc)
dk_chunk = einsum('b h i j, b i h d -> b j h d', ds, qc)
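For reference, a minimal runnable sketch of the hoisting done above, assuming the (batch, seq, heads, dim) layout used in the diff; the toy shapes, `bucket_size`, and the loop skeleton are illustrative stand-ins, not the package's actual API:

import torch
from einops import rearrange

b, n, h, d = 2, 512, 8, 64
bucket_size = 128

o  = torch.randn(b, n, h, d)   # attention output saved from the forward pass
do = torch.randn(b, n, h, d)   # upstream gradient w.r.t. o

# softmax D depends only on o and do, never on the key / value chunks that
# rotate around the ring, so it is computed once before the ring reduce ...
D = (do * o).sum(dim = -1, keepdim = True)    # (b, n, h, 1)
D = rearrange(D, 'b n h 1 -> b h n 1')        # align with the (b, h, i, j) attention layout

# ... and each row bucket inside the ring loop then reuses its slice of D,
# instead of recomputing (doc * oc).sum(...) on every ring pass
for Dc in D.split(bucket_size, dim = -2):
    assert Dc.shape == (b, h, bucket_size, 1)  # consumed as ds = p * scale * (dp - Dc)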
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.7',
version = '0.3.8',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
