einx does not play well with torch.compile yet
lucidrains committed Apr 6, 2024
1 parent 8225bfc commit 6dc78bc
Showing 5 changed files with 36 additions and 35 deletions.
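
Every change below follows the same pattern: einx's string-first calls (pattern first, tensors after) are replaced by einops calls that take the tensor first, or by plain tensor ops, since, per the commit message, einx does not play well with torch.compile yet. A minimal sketch of the signature difference, using a shape taken from the diff below:

import torch
from einops import rearrange

out = torch.randn(2, 1024, 8, 64)             # (b, n, h, d)

# einx style (removed here):  rearrange('b n h d -> b n (h d)', out)
# einops style (added here):  tensor first, then the pattern
out = rearrange(out, 'b n h d -> b n (h d)')  # shape (2, 1024, 512)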
2 changes: 1 addition & 1 deletion ring_attention_pytorch/distributed.py
@@ -54,7 +54,7 @@ def all_gather_variable_dim(t, dim = 0, sizes = None):
gathered_tensors = torch.cat(gathered_tensors, dim = dim)
seq = torch.arange(max_size, device = device)

- mask = einx.less('j, i -> (i j)', seq, sizes)
+ mask = seq[None, :] < sizes[:, None]
seq = torch.arange(mask.shape[-1], device = device)
indices = seq[mask]

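The replacement above is a plain broadcast comparison; einx.less's '(i j)' output additionally flattens the two axes into one, which a reshape recovers. A standalone sketch of the relationship (not part of the file):

import torch

sizes = torch.tensor([2, 5, 3])          # i = 3 ranks, each with its own valid length
seq = torch.arange(5)                    # j = max_size positions

mask_2d = seq[None, :] < sizes[:, None]  # (i, j), True where the position is within that rank's size
mask_flat = mask_2d.reshape(-1)          # matches einx.less('j, i -> (i j)', seq, sizes)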
40 changes: 20 additions & 20 deletions ring_attention_pytorch/ring_attention.py
@@ -6,8 +6,7 @@
from torch.cuda.amp import autocast
from torch.nn import Module, ModuleList

- import einx
- from einx import rearrange
+ from einops import rearrange, repeat

from beartype import beartype

@@ -65,11 +64,12 @@ def default_attention(
sim = torch.where(causal_mask, mask_value, sim)

elif exists(mask):
- sim = einx.where('b j, b h i j, -> b h i j', mask, sim, mask_value)
+ mask = rearrange(mask, 'b j -> b 1 1 j')
+ sim = sim.masked_fill(~mask, mask_value)

# attend

- attn = einx.softmax('b h i [j]', sim)
+ attn = sim.softmax(dim = -1)

# aggregate

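The masking change above is the recurring substitution in this commit: einx.where('b j, b h i j, -> b h i j', mask, sim, mask_value) keeps sim where mask is True and writes mask_value elsewhere, which broadcasting plus masked_fill reproduces. A small sketch with toy shapes:

import torch
from einops import rearrange

b, h, i, j = 2, 4, 8, 8
sim = torch.randn(b, h, i, j)
mask = torch.rand(b, j) > 0.1                      # True = keep this key position
mask_value = -torch.finfo(sim.dtype).max

sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), mask_value)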
@@ -119,11 +119,11 @@ def forward(
ring_stride = get_world_size() * buckets

pos = torch.arange(seq_len // buckets, device = device)
- pos = rearrange('n -> n b', pos, b = buckets)
+ pos = repeat(pos, 'n -> n b', b = buckets)

pos = pos * ring_stride
pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
- pos = rearrange('n b -> (b n)', pos)
+ pos = rearrange(pos, 'n b -> (b n)')

else:
pos = torch.arange(seq_len, device = device)
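Note that the first call above changes name as well as argument order: einops' rearrange cannot introduce a new axis, so broadcasting 'n -> n b' needs repeat instead. A sketch with toy sizes:

import torch
from einops import repeat

buckets = 4
pos = torch.arange(8)                            # positions within one bucket

pos_rep = repeat(pos, 'n -> n b', b = buckets)   # (8, 4), every column equals pos
# a view-based alternative: pos[:, None].expand(-1, buckets)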
@@ -132,7 +132,7 @@
pos = torch.arange(seq_len, device = device)

pos = pos.type_as(self.inv_freq)
- freqs = torch.einsum('i , j -> i j', pos, self.inv_freq)
+ freqs = einsum('i , j -> i j', pos, self.inv_freq)
return torch.cat((freqs, freqs), dim = -1)

def rotate_half(x):
@@ -141,7 +141,7 @@ def rotate_half(x):

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
- pos = rearrange('n d -> n 1 d', pos)
+ pos = rearrange(pos, 'n d -> n 1 d')
return t * pos.cos() + rotate_half(t) * pos.sin()

# batch to sequence sharding and back
@@ -207,7 +207,7 @@ def sharded_batch_to_sharded_seq(

num_sharded_batches = world_size // total_split_seq

- x = rearrange('(b s) n -> b (s n)', x, s = num_sharded_batches)
+ x = rearrange(x, '(b s) n -> b (s n)', s = num_sharded_batches)

# then split sequence across machines

@@ -216,7 +216,7 @@
x, _ = split_by_rank(x)

if exists(mask):
- mask = rearrange('(b s) n -> b (s n)', mask, s = num_sharded_batches)
+ mask = rearrange(mask, '(b s) n -> b (s n)', s = num_sharded_batches)
mask = mask.split(seq_size, dim = -1)
mask, _ = split_by_rank(mask)

@@ -231,7 +231,7 @@ def sharded_seq_to_sharded_batch(

logits, _ = all_gather(logits)

- logits = rearrange('b (s n) c -> (b s) n c', logits, s = num_sharded_batches)
+ logits = rearrange(logits, 'b (s n) c -> (b s) n c', s = num_sharded_batches)

logits = logits.split(sizes.tolist(), dim = 0)

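The '(b s) n -> b (s n)' and 'b (s n) c -> (b s) n c' patterns in these two hunks only regroup axes between the batch and sequence dimensions. A sketch of the forward direction, assuming toy sizes:

import torch
from einops import rearrange

num_sharded_batches = 2                     # s
x = torch.randn(6, 128)                     # ((b s), n) with b = 3

x = rearrange(x, '(b s) n -> b (s n)', s = num_sharded_batches)   # (3, 256)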
@@ -344,17 +344,17 @@ def forward(
x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

if self.striped_ring_attn:
- x = rearrange('b (i j) d -> b (j i) d', x, i = striped_bucket_size)
+ x = rearrange(x, 'b (i j) d -> b (j i) d', i = striped_bucket_size)

if exists(mask):
- mask = rearrange('b (i j) -> b (j i)', mask, i = striped_bucket_size)
+ mask = rearrange(mask, 'b (i j) -> b (j i)', i = striped_bucket_size)

(x, mask), batch_sizes = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

device = x.device

qkv = self.to_qkv(x)
- q, k, v = rearrange('b n (qkv h d) -> qkv b n h d', qkv, qkv = 3, h = self.heads)
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b n h d', qkv = 3, h = self.heads)

# rotary relative positions

@@ -400,7 +400,7 @@ def forward(

# combine heads

- out = rearrange('b n h d -> b n (h d)', out)
+ out = rearrange(out, 'b n h d -> b n (h d)')
out = self.to_out(out)

if auto_shard_seq:
@@ -561,13 +561,13 @@ def forward(
# for workload balancing https://arxiv.org/abs/2311.09431 - MIT paper from Brandon et al.

if self.striped_ring_attn:
- x = rearrange('b (i j) -> b (j i)', x, i = striped_bucket_size)
+ x = rearrange(x, 'b (i j) -> b (j i)', i = striped_bucket_size)

if exists(labels):
- labels = rearrange('b (i j) -> b (j i)', labels, i = striped_bucket_size)
+ labels = rearrange(labels, 'b (i j) -> b (j i)', i = striped_bucket_size)

if exists(mask):
- mask = rearrange('b (i j) -> b (j i)', mask, i = striped_bucket_size)
+ mask = rearrange(mask, 'b (i j) -> b (j i)', i = striped_bucket_size)

# gather across batch and divide across world

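The 'b (i j) -> b (j i)' rearranges above stripe the sequence for the workload balancing referenced in the comment: reshape into (i, j) blocks, swap the two axes, and flatten again, which interleaves the i buckets. A tiny numeric sketch:

import torch
from einops import rearrange

x = torch.arange(8)[None, :]                          # one sequence: 0..7
striped = rearrange(x, 'b (i j) -> b (j i)', i = 2)   # i plays the role of striped_bucket_size

print(striped)   # tensor([[0, 4, 1, 5, 2, 6, 3, 7]])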
@@ -605,7 +605,7 @@ def forward(
# handle returning of loss

if return_loss:
- logits = rearrange('b n c -> b c n', logits)
+ logits = rearrange(logits, 'b n c -> b c n')

ce_loss = F.cross_entropy(
logits,
@@ -623,6 +623,6 @@ def forward(
logits = sharded_seq_to_sharded_batch(logits, batch_sizes, num_sharded_batches)

if self.striped_ring_attn:
- logits = rearrange('b (j i) d -> b (i j) d', logits, i = striped_bucket_size)
+ logits = rearrange(logits, 'b (j i) d -> b (i j) d', i = striped_bucket_size)

return logits[:, :seq_len]
17 changes: 9 additions & 8 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -6,8 +6,7 @@
from torch import nn, einsum, Tensor
from torch.autograd.function import Function

- import einx
- from einx import rearrange
+ from einops import rearrange

from ring_attention_pytorch.ring import (
ring_pass,
@@ -149,7 +148,8 @@ def forward(
attn_weights = einsum('b i h d, b j h d -> b h i j', qc, kc) * scale

if exists(col_mask):
- attn_weights = einx.where('b j, b h i j, -> b h i j', col_mask, attn_weights, max_neg_value)
+ col_mask_unsqueezed = rearrange(col_mask, 'b j -> b 1 1 j')
+ attn_weights = attn_weights.masked_fill(~col_mask_unsqueezed, max_neg_value)

if causal:
qk_len_diff = kc.shape[-3] - qc.shape[-3]
@@ -177,7 +177,7 @@ def forward(
exp_weights = torch.exp(attn_weights - new_row_maxes)

if exists(col_mask):
- exp_weights = einx.where('b j, b h i j, -> b h i j', col_mask, exp_weights, 0.)
+ exp_weights = exp_weights.masked_fill(~col_mask_unsqueezed, 0.)

block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

@@ -187,13 +187,13 @@

new_row_sums = exp_row_max_diff * row_sums + block_row_sums

- exp_row_max_diff = rearrange('b h n 1 -> b n h 1', exp_row_max_diff)
+ exp_row_max_diff = rearrange(exp_row_max_diff, 'b h n 1 -> b n h 1')
oc.mul_(exp_row_max_diff).add_(exp_values)

row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)

- o.div_(rearrange('b h n 1 -> b n h 1', all_row_sums))
+ o.div_(rearrange(all_row_sums, 'b h n 1 -> b n h 1'))

lse = all_row_sums.clamp(min = EPSILON).log() + all_row_maxes

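The 'b h n 1 -> b n h 1' rearranges above only move the head axis inward so the running row statistics broadcast against the (batch, seq, heads, dim) output accumulator; a transpose of dims 1 and 2 does the same. Sketch:

import torch
from einops import rearrange

all_row_sums = torch.rand(2, 4, 128, 1)                        # (b, h, n, 1)

via_rearrange = rearrange(all_row_sums, 'b h n 1 -> b n h 1')  # (2, 128, 4, 1)
via_transpose = all_row_sums.transpose(1, 2)                   # same values, as a view
assert torch.equal(via_rearrange, via_transpose)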
@@ -305,13 +305,14 @@ def backward(ctx, do):
p = torch.exp(attn_weights - lsec)

if exists(col_mask):
- p = einx.where('b j, b h i j, -> b h i j', col_mask, p, 0.)
+ col_mask_unsqueezed = rearrange(col_mask, 'b j -> b 1 1 j')
+ p = p.masked_fill(~col_mask_unsqueezed, 0.)

dv_chunk = einsum('b h i j, b i h d -> b j h d', p, doc)
dp = einsum('b i h d, b j h d -> b h i j', doc, vc)

D = (doc * oc).sum(dim = -1, keepdims = True)
- D = rearrange('b n h 1 -> b h n 1', D)
+ D = rearrange(D, 'b n h 1 -> b h n 1')
ds = p * scale * (dp - D)

dq_chunk = einsum('b h i j, b j h d -> b i h d', ds, kc)
8 changes: 4 additions & 4 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -20,7 +20,7 @@

from beartype import beartype

- from einx import rearrange
+ from einops import rearrange, repeat

# helpers

@@ -387,7 +387,7 @@ def flash_attn_forward(
assert bias.is_cuda

if bias.ndim == 2:
- bias = rearrange('b j -> b h i j', bias, h = nheads, i = seqlen_q)
+ bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)

if not is_contiguous(bias):
bias = bias.contiguous()
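Broadcasting the 2-d bias above to per-head, per-query-row shape is plain repetition; an unsqueeze-plus-expand view expresses the same thing. A sketch with made-up sizes:

import torch
from einops import repeat

nheads, seqlen_q = 4, 128
bias = torch.randn(2, 256)                                            # (b, j)

bias_4d = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q)    # (2, 4, 128, 256)
# view-based alternative: bias[:, None, None, :].expand(-1, nheads, seqlen_q, -1)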
@@ -606,7 +606,7 @@ def forward(
m = m[..., :q_seq_len]

o_scale = torch.exp(m - lse)
- o.mul_(rearrange('b h n -> b n h 1', o_scale))
+ o.mul_(rearrange(o_scale, 'b h n -> b n h 1'))

ctx.args = (
causal,
@@ -698,7 +698,7 @@ def backward(ctx, do):
# prepare row related tensors with unpad_input

if not causal and exists(mask):
- lse = rearrange('b h n ... -> b n h ...', lse)
+ lse = rearrange(lse, 'b h n ... -> b n h ...')

(
(q, o, do, lse),
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
- version = '0.3.2',
+ version = '0.3.3',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
@@ -17,7 +17,7 @@
],
install_requires=[
'beartype',
- 'einx[torch]>=0.1.3',
+ 'einops>=0.7.0',
'torch>=2.0'
],
classifiers=[
