updated: algorithm 3 in the tree attn decoding paper is more concise
lucidrains committed Aug 14, 2024
1 parent d49499f commit 70fb3c6
Showing 2 changed files with 9 additions and 16 deletions.
23 changes: 8 additions & 15 deletions ring_attention_pytorch/tree_attn_decoding.py
@@ -59,7 +59,7 @@ def tree_attn_decode(
     if use_triton and q.is_cuda:
         from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
 
-        local_out, local_max, lse = flash_attn_forward(
+        local_out, _, lse = flash_attn_forward(
             q, k, v,
             causal = False,
             return_normalized_output = True,
@@ -72,34 +72,27 @@
         scale = q.shape[-1] ** -0.5
         sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
 
-        local_max = sim.amax(dim = -1, keepdim = True)
-        sim -= local_max
+        sim -= sim.amax(dim = -1, keepdim = True)
         lse = sim.logsumexp(dim = -1, keepdim = True)
 
         attn = sim.softmax(dim = -1)
         local_out = einsum('... i j, ... j d -> ... i d', attn, v)
 
-        den = lse.exp()
-        num = local_out.float() * den
-
     else:
         # handle edge case where seq length < world size
 
-        num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
-        local_max = torch.zeros_like(den)
+        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
+        lse = torch.full_like(den, -torch.finfo(torch.float32).max)
 
     # first get global max through an all reduce (max)
 
-    global_max = local_max.clone()
-    dist.all_reduce(global_max, dist.ReduceOp.MAX)
+    global_lse = lse.clone()
+    dist.all_reduce(global_lse, dist.ReduceOp.MAX)
 
     # renormalize the numerator and denominators
 
-    renorm_factor = (local_max - global_max).exp()
-
-    den *= renorm_factor
-    num *= renorm_factor
+    den = (lse - global_lse).exp()
+    num = local_out * den
 
     # second and third all reduce (sum)
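The change follows from how softmax attention decomposes across shards of the key/value sequence. If rank r computes a normalized local output out_r together with the log-sum-exp lse_r of its attention logits, then for any shared reference L (here the all-reduced maximum of the lse's) the exact global output is

    out = sum_r exp(lse_r - L) * out_r / sum_r exp(lse_r - L)

so a single lse per rank replaces the previous separate bookkeeping of local_max, num and den. Below is a minimal, hedged sketch of that merge rule, simulating the kv shards on one process instead of with dist.all_reduce; the helper names partial_attend and merge_shards are illustrative rather than from the repo, and lse here is the unshifted log-sum-exp.

import torch
from torch import einsum

def partial_attend(q, k, v):
    # attention over one key/value shard, returning the normalized local
    # output together with the log-sum-exp of the scaled logits
    # (torch's logsumexp does the max-shift internally for stability)
    scale = q.shape[-1] ** -0.5
    sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
    lse = sim.logsumexp(dim = -1, keepdim = True)
    out = einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v)
    return out, lse

def merge_shards(outs, lses):
    # stand-in for the three all-reduces in the commit:
    # one MAX over the lse's, then SUMs over numerator and denominator
    global_lse = torch.stack(lses).amax(dim = 0)           # all_reduce (MAX)
    dens = [(lse - global_lse).exp() for lse in lses]      # each factor <= 1
    num = sum(out * den for out, den in zip(outs, dens))   # all_reduce (SUM)
    den = sum(dens)                                        # all_reduce (SUM)
    return num / den

# a single decoding query against a kv sequence split across 4 "devices"
q = torch.randn(1, 1, 64)
k = torch.randn(1, 1024, 64)
v = torch.randn(1, 1024, 64)

outs, lses = zip(*(partial_attend(q, ks, vs) for ks, vs in zip(k.chunk(4, dim = -2), v.chunk(4, dim = -2))))

# merging the shards should match attention over the full, unsharded sequence
full, _ = partial_attend(q, k, v)
assert torch.allclose(merge_shards(outs, lses), full, atol = 1e-5)

Since lse <= global_lse on every rank, the renormalization factor exp(lse - global_lse) lies in (0, 1], which is why the MAX all-reduce must happen before the two SUM all-reduces: it pins the shared reference so neither num nor den can overflow.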
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.10',
+  version = '0.5.11',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',