diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index 499dfe7..ce71480 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -557,6 +557,10 @@ def forward(
         receive_kv = None
         receive_mask = None

+        # non-causal and causal striped attention can have final normalization of output fused
+
+        can_fuse_final_output_normalization = not causal or (causal and striped_ring_attn)
+
         for (ring_rank, (is_first, is_last)), ((kv, mask), (receive_kv, receive_mask)) in ring_pass_fn(kv, mask, receive_buffers = (receive_kv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):

             k, v = kv
@@ -593,15 +597,17 @@ def forward(
                 bias = bias,
                 softmax_scale = softmax_scale,
                 causal_mask_diagonal = causal_mask_diagonal,
-                return_normalized_output = False,
+                return_normalized_output = can_fuse_final_output_normalization and is_last,
                 load_accumulated = not is_first
             )

             lse = lse[..., :q_seq_len]

-            m = m[..., :q_seq_len]
-            o_scale = torch.exp(m - lse)
-            o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))
+            if not can_fuse_final_output_normalization:
+                m = m[..., :q_seq_len]
+
+                o_scale = torch.exp(m - lse)
+                o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))

         ctx.args = (
             causal,
diff --git a/setup.py b/setup.py
index d5de77d..c347516 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.5',
+  version = '0.3.6',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
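
For reference, a minimal standalone sketch of the `exp(m - lse)` rescaling that this patch defers to the kernel on the last ring pass (when `return_normalized_output` is true). This is an illustration only, not the library's kernel: the names `s`, `v`, `m`, `lse`, and `o_unnorm` are hypothetical, and the example uses a single head with dense scores rather than the blockwise accumulation the CUDA path performs.

```python
import torch

torch.manual_seed(0)

n, d = 4, 8                       # sequence length, head dimension (arbitrary for the sketch)
s = torch.randn(n, n)             # attention scores for one head
v = torch.randn(n, d)             # values

# state a flash-attention style accumulator carries per query row
m = s.amax(dim = -1)                                     # running row max
o_unnorm = (s - m[:, None]).exp() @ v                    # output accumulated without the softmax denominator
lse = m + (s - m[:, None]).exp().sum(dim = -1).log()     # log-sum-exp of the scores

# exp(m - lse) is exactly 1 / (softmax denominator), so multiplying the
# unnormalized accumulator by it recovers softmax(s) @ v - the same role
# o.mul_(rearrange(o_scale, 'b h n -> b n h 1')) plays per batch and head
o = o_unnorm * (m - lse).exp()[:, None]

assert torch.allclose(o, s.softmax(dim = -1) @ v, atol = 1e-5)
```

Because this rescale only depends on the final `m` and `lse`, it can be applied once at the end; fusing it into the last kernel call for the non-causal and striped-causal cases avoids the extra elementwise multiply after the ring loop.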