fix ring attention
zhangyuqin1998 committed Jul 10, 2024
1 parent e681019 commit 4cd6dbc
Showing 2 changed files with 51 additions and 85 deletions.
127 changes: 47 additions & 80 deletions paddlenlp/transformers/ring_flash_attention.py
@@ -55,17 +55,11 @@ def wait(self):

def add_to_buffers(self, key, value):
if key.shape != self._k_buffer[self._next_buffer_idx].shape:
k_buffer_chunk = paddle.slice(
self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
)
v_buffer_chunk = paddle.slice(
self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
)
k_buffer_chunk += key
v_buffer_chunk += value
self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(key)
self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(value)

else:
self._k_buffer[self._next_buffer_idx] += key
self._v_buffer[self._next_buffer_idx] += value
self._k_buffer[self._next_buffer_idx].add_(key)
self._v_buffer[self._next_buffer_idx].add_(value)

def get_buffers(self):
return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx]
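The rewritten add_to_buffers relies on basic slicing returning a view that aliases the underlying buffer, so add_ accumulates in place; a minimal check of that assumption (illustrative sketch, not part of the change):

import paddle

buf = paddle.zeros([1, 4, 2, 2])
chunk = paddle.ones([1, 2, 2, 2])
# same pattern as add_to_buffers: slice the buffer, then accumulate in place
buf[:, : chunk.shape[1], :, :].add_(chunk)
print(buf.sum().item())  # 8.0 if the slice aliases buf's storage, 0.0 otherwise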
@@ -84,23 +78,19 @@ def send_recv(self):


def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False):
if old_out is None and old_lse is None:
return block_out.to("float32"), block_lse.to("float32")

if second_chunk_only:
second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]

second_chunk_out, second_chunk_lse = update_out_and_lse(
second_chunk_out_, second_chunk_lse_, block_out, block_lse
second_chunk_out, second_chunk_lse, block_out, block_lse
)
paddle.assign(second_chunk_out, second_chunk_out_)
paddle.assign(second_chunk_lse, second_chunk_lse_)
old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse

return old_out, old_lse
else:
block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
with paddle.amp.auto_cast(enable=False, dtype="bfloat16"):
lse = old_lse - F.log_sigmoid(old_lse - block_lse)
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse
return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid(
old_lse - block_lse
)
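The closed form above is the usual log-sum-exp merge of two attention partials; a quick NumPy check of the identity used in update_out_and_lse (illustrative only, not part of the change):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# two partial results with their log-sum-exp statistics
old_out, old_lse = np.array([0.3]), np.array([1.2])
block_out, block_lse = np.array([0.9]), np.array([0.4])

# formulas from update_out_and_lse (log_sigmoid written out explicitly)
lse = old_lse - np.log(sigmoid(old_lse - block_lse))
out = old_out - (old_out - block_out) * sigmoid(block_lse - old_lse)

# reference: renormalize the two partials explicitly
ref_lse = np.log(np.exp(old_lse) + np.exp(block_lse))
ref_out = (np.exp(old_lse) * old_out + np.exp(block_lse) * block_out) / np.exp(ref_lse)

np.testing.assert_allclose(lse, ref_lse)
np.testing.assert_allclose(out, ref_out)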


def get_chunk_id(rank, cp_size):
@@ -130,14 +120,10 @@ def balanced_ring_flash_attention_fwd_func(
comm_buffer = RingCommunicator(group, local_key, local_value)
local_q_seq_len = local_query.shape[1]

out, lse, k_cache, v_cache = None, None, dict(), dict()

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]

for step in range(cp_size):
block_k, block_v = comm_buffer.get_buffers()

@@ -159,16 +145,19 @@ def balanced_ring_flash_attention_fwd_func(
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

if step == 0:
out, lse = block_out, block_lse

else:
out, lse = update_out_and_lse(out, lse, block_out, block_lse)

else:
# block_k and block_v is from rank (group.rank - step) % cp_size
if step == 0:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = block_out, block_lse

elif step > rank:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query_second_chunk,
@@ -182,16 +171,14 @@ def balanced_ring_flash_attention_fwd_func(
not training,
"",
)
block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
else:
block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query,
block_k,
block_v,
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
fixed_seed_offset,
None,
dropout,
Expand All @@ -200,23 +187,19 @@ def balanced_ring_flash_attention_fwd_func(
not training,
"",
)
block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

out, lse = update_out_and_lse(out, lse, block_out, block_lse)
k_cache[step] = block_k
v_cache[step] = block_v

# TODO(zhangyuqin1998): wait() cannot be called under the batch_isend_irecv async stream and needs a fix; this hurts performance.
# if step != cp_size - 1:
# comm_buffer.wait()
paddle.device.synchronize()

out = out.to(local_query.dtype)
lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
return out, lse, k_cache, v_cache
return out.to(local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1])
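For reference, a small pure-Python sketch (not PaddleNLP code) of which branch of the causal path above runs at each step, using the rule from the comment that at step s rank r holds the K/V originating from rank (r - s) % cp_size:

cp_size = 4  # hypothetical context-parallel group size
for rank in range(cp_size):
    for step in range(cp_size):
        src = (rank - step) % cp_size  # origin of block_k / block_v at this step
        if step == 0:
            branch = "causal attention on the local K/V block"
        elif step > rank:
            branch = "second half of local Q vs the full incoming K/V"
        else:
            branch = "full local Q vs the first half of the incoming K/V"
        print(f"rank {rank}, step {step}: K/V from rank {src} -> {branch}")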


def balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
local_query,
local_key,
@@ -240,17 +223,10 @@ def balanced_ring_flash_attention_bwd_func(
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
local_query_second_chunk = paddle.slice(
local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_out_second_chunk = paddle.slice(
local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
query_grad_buffer_second_chunk = paddle.slice(
query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
)
local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :]

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -274,13 +250,13 @@ def balanced_ring_flash_attention_bwd_func(
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

else:
if step == 0:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

elif step > rank:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query_second_chunk,
@@ -294,12 +270,12 @@ def balanced_ring_flash_attention_bwd_func(
dropout,
False,
)
query_grad_buffer_second_chunk += block_q_grad
query_grad_buffer[:, local_q_seq_len // 2 :, :, :].add_(block_q_grad)

else:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query,
k_cache[step],
v_cache[step],
block_k[:, : local_q_seq_len // 2, :, :],
block_v[:, : local_q_seq_len // 2, :, :],
local_out,
lse,
fixed_seed_offset,
Expand All @@ -308,9 +284,12 @@ def balanced_ring_flash_attention_bwd_func(
dropout,
False,
)
query_grad_buffer += block_q_grad
query_grad_buffer.add_(block_q_grad)

# TODO(zhangyuqin1998): wait() cannot be called under the batch_isend_irecv async stream and needs a fix; this hurts performance.
# if step != cp_size - 1:
# kv_comm_buffer.wait()
# if step != 0:
# grad_comm_buffer.wait()
paddle.device.synchronize()

grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
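The key/value gradient buffers travel the same ring as the K/V blocks, so every origin rank ends up accumulating exactly one partial dK/dV from each peer; a small pure-Python sketch of that bookkeeping (an illustrative assumption about the schedule, not PaddleNLP API):

cp_size = 4  # hypothetical context-parallel group size
received_from = {origin: [] for origin in range(cp_size)}
for step in range(cp_size):
    for rank in range(cp_size):
        origin = (rank - step) % cp_size  # owner of the K/V block seen at this step
        received_from[origin].append((rank, step))

# each origin collects one partial gradient from every rank over the cp_size steps
assert all(len(parts) == cp_size for parts in received_from.values())
print(received_from[0])  # [(0, 0), (1, 1), (2, 2), (3, 3)]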
@@ -344,10 +323,10 @@ def forward(
if attn_mask is not None:
is_causal = False

out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
out, lse = balanced_ring_flash_attention_fwd_func(
group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
)
ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
ctx.save_for_backward(query, key, value, out, lse, attn_mask)

ctx.group = group
ctx.fixed_seed_offset = fixed_seed_offset
ctx.dropout = dropout
@@ -356,7 +335,7 @@ def forward(

@staticmethod
def backward(ctx, out_grad):
query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
query, key, value, out, lse, attn_mask = ctx.saved_tensor()

group = ctx.group
fixed_seed_offset = ctx.fixed_seed_offset
dropout = ctx.dropout
@@ -366,19 +345,7 @@ def backward(ctx, out_grad):
fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)

query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
group,
k_cache,
v_cache,
out_grad,
query,
key,
value,
out,
lse,
fixed_seed_offset,
attn_mask,
dropout,
is_causal,
group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
)
if attn_mask is not None and not attn_mask.stop_gradient:
return query_grad, key_grad, value_grad, None
9 changes: 4 additions & 5 deletions tests/transformers/test_ring_flash_attention.py
@@ -83,17 +83,16 @@ def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, us
)
ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask)

local_out.mean().backward()
ref_out.mean().backward()
local_out.backward()
ref_out.backward()

ref_local_query_grad = self.split_belanced_data(query.grad)
ref_local_key_grad = self.split_belanced_data(key.grad)
ref_local_value_grad = self.split_belanced_data(value.grad)

ref_local_out = self.split_belanced_data(ref_out)

rtol = 1e-04
atol = 5e-03
rtol = 1e-02
atol = 1e-02
np.testing.assert_allclose(
local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol
)
