Commit 423279f

successful tests on 8 A40s on runpod.io for quarter million tokens
lucidrains committed Apr 7, 2024
1 parent d786f3f commit 423279f
Showing 4 changed files with 6 additions and 5 deletions.
README.md: 3 changes (1 addition, 2 deletions)

@@ -93,11 +93,10 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
- [x] validate cuda causal and striped ring attention works
- [x] make sure cuda striped attention works for multiple buckets, otherwise flash attention is ineffective
- [x] for cuda striped attention, for backwards hack, pad the extra token once and index out when passing into Tri's cuda kernel
- [x] find a machine with 8 GPUs and test with a quarter million tokens first

- [ ] find a machine with 8 GPUs and test with a quarter million tokens first
- [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
- [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- [ ] `batch_isend_irecv` in the presence of key padding mask needing ring exchange, but not a big priority
- [ ] figure out how to pytest distributed pytorch
- [ ] use sdp context manager to validate when it is possible to use `ring_flash_attn_cuda`, otherwise assert out

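One of the still-open items above is the sequence-sharding `Dataset`. As a rough illustration only (nothing below is part of the repository; the class name `ShardedSeqDataset` and the windowing scheme are assumptions), a shard-per-rank dataset that keeps cross entropy labels aligned might look like:

```python
# Hypothetical sketch only -- not part of this commit. Each rank receives a
# contiguous slice of the full sequence, plus one extra token so that labels
# for cross entropy are simply the slice shifted by one position.
from torch.utils.data import Dataset

class ShardedSeqDataset(Dataset):
    def __init__(self, token_ids, ring_seq_size, rank, world_size):
        assert ring_seq_size % world_size == 0
        self.shard_len = ring_seq_size // world_size
        self.rank = rank
        # windows of (ring_seq_size + 1) tokens; the +1 supplies the label
        # for the final position of the last rank's shard
        self.windows = token_ids.unfold(0, ring_seq_size + 1, ring_seq_size)

    def __len__(self):
        return self.windows.shape[0]

    def __getitem__(self, idx):
        window = self.windows[idx]
        start = self.rank * self.shard_len
        shard = window[start : start + self.shard_len + 1]
        return shard[:-1], shard[1:]  # (inputs, labels) for this rank only
```

Each rank would build the dataset with its own `rank`, so its DataLoader yields exactly the slice that rank's ring attention segment expects.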
assert.py: 1 change (1 addition, 0 deletions)

@@ -65,6 +65,7 @@ def start(
depth = 2,
dim_head = 8,
ring_attn = False,
ring_seq_size = ring_seq_size,
bucket_size = bucket_size,
use_cuda_kernel = False
)
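The one-line addition forwards the script's `ring_seq_size` to the model constructor instead of relying on the model's default, so the long-sequence configuration actually reaches the attention layers. A rough sanity check for the run in the commit message (262144 is my reading of "quarter million"; the bucket size is purely illustrative and not taken from this diff):

```python
# Assumed numbers for the 8 x A40 run -- only ring_seq_size % bucket_size == 0
# is actually enforced by the code touched in this commit.
world_size    = 8          # one rank per A40
ring_seq_size = 262_144    # roughly a quarter million tokens across the whole ring
bucket_size   = 512        # illustrative value

per_rank_tokens = ring_seq_size // world_size   # 32768 tokens held per GPU
assert ring_seq_size % bucket_size == 0         # mirrors the guard added below
print(per_rank_tokens)
```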
ring_attention_pytorch/ring_attention.py: 5 changes (3 additions, 2 deletions)

@@ -275,7 +275,7 @@ def __init__(
self.scale = dim_head ** -0.5
self.causal = causal

assert divisible_by(ring_seq_size, bucket_size)
assert (not ring_attn) or divisible_by(ring_seq_size, bucket_size), f'ring seq size {ring_seq_size} is not divisible by bucket size {bucket_size}'

self.ring_attn = ring_attn
self.max_lookback_seq_len = max_lookback_seq_len
@@ -468,7 +468,8 @@ def __init__(

self.ring_seq_size = ring_seq_size
self.bucket_size = bucket_size
assert divisible_by(ring_seq_size, bucket_size)

assert (not ring_attn) or divisible_by(ring_seq_size, bucket_size), f'ring seq size {ring_seq_size} is not divisible by bucket size {bucket_size}'

self.auto_shard_seq = default(auto_shard_seq, ring_attn) # if ring attention is turned on, auto-shard across sequence dimension. this can also be turned off and done manually elsewhere in the data loading

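Both hunks relax the same check: the divisibility assert now only applies when ring attention is actually enabled, and it gained an error message. A standalone sketch of the new behaviour, with `divisible_by` filled in by assumption as the usual modulo helper (its real definition lives elsewhere in `ring_attention.py`):

```python
def divisible_by(num, den):
    return (num % den) == 0

def check_ring_config(ring_attn, ring_seq_size, bucket_size):
    # the constraint is only meaningful when ring attention is active, so
    # non-ring configurations no longer trip the assert
    assert (not ring_attn) or divisible_by(ring_seq_size, bucket_size), \
        f'ring seq size {ring_seq_size} is not divisible by bucket size {bucket_size}'

check_ring_config(ring_attn = False, ring_seq_size = 1000, bucket_size = 512)  # passes after this commit
check_ring_config(ring_attn = True,  ring_seq_size = 1024, bucket_size = 512)  # passes
# a misconfigured ring setup now fails loudly with the message above:
# check_ring_config(ring_attn = True, ring_seq_size = 1000, bucket_size = 512)
```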
setup.py: 2 changes (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.4',
version = '0.3.5',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
