keep numerator and denominator in float32 for tree attn decoding
lucidrains committed Aug 14, 2024
1 parent ed59cee commit d49499f
Showing 2 changed files with 11 additions and 9 deletions.
18 changes: 10 additions & 8 deletions ring_attention_pytorch/tree_attn_decoding.py
@@ -25,6 +25,8 @@ def tree_attn_decode(
     shard_kv_seq = False,
     use_triton = None
 ):
+    dtype = q.dtype
+
     assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
 
     if exists(k):
@@ -37,8 +39,6 @@
         https://arxiv.org/abs/2408.04093
         """
 
-        dim_v = v.shape[-1]
-
         # each machine (rank) takes care of a chunk of kv sequence within the world of many machines
 
         if shard_kv_seq:
@@ -59,7 +59,7 @@
         if use_triton and q.is_cuda:
             from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
 
-            out, local_max, lse = flash_attn_forward(
+            local_out, local_max, lse = flash_attn_forward(
                 q, k, v,
                 causal = False,
                 return_normalized_output = True,
@@ -77,16 +77,16 @@
             lse = sim.logsumexp(dim = -1, keepdim = True)
 
             attn = sim.softmax(dim = -1)
-            out = einsum('... i j, ... j d -> ... i d', attn, v)
+            local_out = einsum('... i j, ... j d -> ... i d', attn, v)
 
         den = lse.exp()
-        num = out * den
+        num = local_out.float() * den
 
     else:
         # handle edge case where seq length < world size
 
-        num = q.new_zeros((*q.shape[:-1], dim_v))
-        den = q.new_zeros((*q.shape[:-1], 1))
+        num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
+        den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
         local_max = torch.zeros_like(den)
 
     # first get global max through an all reduce (max)
@@ -106,4 +106,6 @@
     dist.all_reduce(den)
     dist.all_reduce(num)
 
-    return num / den.clamp(min = eps)
+    out = num.div_(den.clamp(min = eps))
+
+    return out.type(dtype)
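
For context, here is a minimal single-process sketch of the combine step this commit hardens. It illustrates the tree-attention decoding math (Algorithm 3 of https://arxiv.org/abs/2408.04093) but is not the repository's exact code: `decode_shard` and `combine` are hypothetical helpers, and plain Python sums stand in for the `dist.all_reduce` calls. The commit's point shows up in the `.float()` casts: each rank's numerator and denominator are accumulated in float32 so the cross-rank reductions neither overflow nor lose precision when q, k, v are fp16 or bf16.

import torch

def decode_shard(q, k, v):
    # per-rank partial attention over one kv shard (hypothetical helper)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('... i d, ... j d -> ... i j', q, k) * scale
    local_max = sim.amax(dim = -1, keepdim = True)
    weights = (sim - local_max).exp()               # stabilized against the shard's own max
    den = weights.sum(dim = -1, keepdim = True)     # shard denominator
    num = torch.einsum('... i j, ... j d -> ... i d', weights, v)  # shard numerator
    # the commit's change: keep num and den in float32 for the cross-rank sums
    return num.float(), den.float(), local_max

def combine(shards):
    # stand-in for dist.all_reduce(local_max, op = ReduceOp.MAX)
    global_max = torch.stack([m for _, _, m in shards]).amax(dim = 0)
    num, den = 0., 0.
    for n, d, m in shards:
        renorm = (m - global_max).exp().float()     # rebase each shard onto the global max
        num = num + n * renorm                      # stand-in for dist.all_reduce(num)
        den = den + d * renorm                      # stand-in for dist.all_reduce(den)
    return num / den.clamp(min = 1e-8)

# check the sharded decode against full attention for a single query token
q = torch.randn(1, 1, 64)
k, v = torch.randn(1, 48, 64), torch.randn(1, 48, 64)
shards = [decode_shard(q, ks, vs) for ks, vs in zip(k.chunk(4, dim = 1), v.chunk(4, dim = 1))]
out = combine(shards)
sim = torch.einsum('... i d, ... j d -> ... i j', q, k) * 64 ** -0.5
ref = torch.einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v)
assert torch.allclose(out, ref, atol = 1e-5)

Because every shard is rebased onto the same global max before summing, `num / den` equals exact softmax attention over the full sequence; the float32 accumulation affects only numerical precision, not the math.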
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.9',
+  version = '0.5.10',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',