final micro optimization to ring flash attn cuda forward
lucidrains committed Apr 7, 2024
1 parent 74d2e7b commit 228d824
Showing 2 changed files with 11 additions and 5 deletions.
ring_attention_pytorch/ring_flash_attention_cuda.py (10 additions, 4 deletions)
@@ -557,6 +557,10 @@ def forward(
         receive_kv = None
         receive_mask = None
 
+        # non-causal and causal striped attention can have final normalization of output fused
+
+        can_fuse_final_output_normalization = not causal or (causal and striped_ring_attn)
+
         for (ring_rank, (is_first, is_last)), ((kv, mask), (receive_kv, receive_mask)) in ring_pass_fn(kv, mask, receive_buffers = (receive_kv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
             k, v = kv
 
@@ -593,15 +597,17 @@ def forward(
                 bias = bias,
                 softmax_scale = softmax_scale,
                 causal_mask_diagonal = causal_mask_diagonal,
-                return_normalized_output = False,
+                return_normalized_output = can_fuse_final_output_normalization and is_last,
                 load_accumulated = not is_first
             )
 
         lse = lse[..., :q_seq_len]
-        m = m[..., :q_seq_len]
+
+        if not can_fuse_final_output_normalization:
+            m = m[..., :q_seq_len]
 
-        o_scale = torch.exp(m - lse)
-        o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))
+            o_scale = torch.exp(m - lse)
+            o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))
 
         ctx.args = (
             causal,
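In short: before this change the CUDA forward always returned an un-normalized output (return_normalized_output = False) and the host code rescaled it after the ring loop with exp(m - lse); now the kernel is asked to return an already-normalized output on the last ring pass whenever the masking pattern allows it (non-causal, or causal striped ring attention), so the extra rescale only runs in the remaining plain-causal case. The snippet below is a minimal sketch of that deferred rescale, not part of the commit; the tensor shapes are invented for illustration, and o, m, lse stand in for the accumulated un-normalized output, running row max, and log-sum-exp.

import torch
from einops import rearrange

b, h, n, d = 1, 2, 8, 16                 # hypothetical batch, heads, sequence length, head dim

o   = torch.randn(b, n, h, d)            # accumulated, still un-normalized attention output
m   = torch.randn(b, h, n)               # running row max of the attention logits
lse = m + torch.rand(b, h, n)            # log-sum-exp of the logits, always >= m

# deferred normalization, matching the lines the diff moves under the
# `if not can_fuse_final_output_normalization` guard
o_scale = torch.exp(m - lse)             # correction factor in (0, 1]
o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))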
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.5',
+  version = '0.3.6',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
