[Bug fixes] Fix ring attention #8740

Merged 3 commits on Jul 11, 2024

Changes from all commits
132 changes: 50 additions & 82 deletions paddlenlp/transformers/ring_flash_attention.py
@@ -55,17 +55,11 @@

    def add_to_buffers(self, key, value):
        if key.shape != self._k_buffer[self._next_buffer_idx].shape:
-            k_buffer_chunk = paddle.slice(
-                self._k_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[key.shape[1]]
-            )
-            v_buffer_chunk = paddle.slice(
-                self._v_buffer[self._next_buffer_idx], axes=[1], starts=[0], ends=[value.shape[1]]
-            )
-            k_buffer_chunk += key
-            v_buffer_chunk += value
+            self._k_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(key)
+            self._v_buffer[self._next_buffer_idx][:, : key.shape[1], :, :].add_(value)
        else:
-            self._k_buffer[self._next_buffer_idx] += key
-            self._v_buffer[self._next_buffer_idx] += value
+            self._k_buffer[self._next_buffer_idx].add_(key)
+            self._v_buffer[self._next_buffer_idx].add_(value)

def get_buffers(self):
return self._k_buffer[self._next_buffer_idx], self._v_buffer[self._next_buffer_idx]
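Note: the rewritten accumulation relies on Paddle's dygraph basic indexing producing a view, so the in-place `add_` writes through to the communication buffer itself instead of into a temporary slice. A minimal standalone sketch of that assumption (not part of the PR):

```python
import paddle

buf = paddle.zeros([1, 4, 2, 2])
chunk = paddle.ones([1, 2, 2, 2])

# Under the view assumption, add_ on the indexed slice mutates buf in place;
# if the indexing returned a copy, the update would be silently lost.
buf[:, : chunk.shape[1], :, :].add_(chunk)
print(buf.sum().item())  # expected: 8.0
```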
@@ -84,23 +78,21 @@


def update_out_and_lse(old_out, old_lse, block_out, block_lse, second_chunk_only=False):
    if old_out is None and old_lse is None:
        return block_out.to("float32"), block_lse.to("float32")

    if second_chunk_only:
-        second_chunk_out_ = paddle.slice(old_out, axes=[1], starts=[old_out.shape[1] // 2], ends=[old_out.shape[1]])
-        second_chunk_lse_ = paddle.slice(old_lse, axes=[1], starts=[old_lse.shape[1] // 2], ends=[old_lse.shape[1]])
+        second_chunk_out = old_out[:, old_out.shape[1] // 2 :, :, :]
+        second_chunk_lse = old_lse[:, old_lse.shape[1] // 2 :, :, :]
        second_chunk_out, second_chunk_lse = update_out_and_lse(
-            second_chunk_out_, second_chunk_lse_, block_out, block_lse
+            second_chunk_out, second_chunk_lse, block_out, block_lse
        )
-        paddle.assign(second_chunk_out, second_chunk_out_)
-        paddle.assign(second_chunk_lse, second_chunk_lse_)
+        old_out[:, old_out.shape[1] // 2 :, :, :] = second_chunk_out
+        old_lse[:, old_lse.shape[1] // 2 :, :, :] = second_chunk_lse
return old_out, old_lse
else:
-        block_out, block_lse = block_out.to("float32"), block_lse.to("float32")
-        with paddle.amp.auto_cast(enable=False, dtype="bfloat16"):
-            lse = old_lse - F.log_sigmoid(old_lse - block_lse)
-            return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), lse
+        block_out, block_lse = paddle.cast(block_out, "float32"), paddle.cast(block_lse, "float32")
+        with paddle.amp.auto_cast(enable=False):
+            return old_out - (old_out - block_out) * F.sigmoid(block_lse - old_lse), old_lse - F.log_sigmoid(
+                old_lse - block_lse
+            )
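Two things are worth spelling out about this hunk. First, the second_chunk_only path now writes the merged halves back through indexed assignment, updating old_out and old_lse storage directly instead of routing through paddle.slice plus paddle.assign. Second, the merge itself is the standard online-softmax combination of two partial attention results in numerically stable form: for a running log-sum-exp a and an incoming block b, log(e^a + e^b) = a - logsigmoid(a - b), and the outputs blend with weight sigmoid(b - a). A scalar sanity check of those identities (plain NumPy, independent of the Paddle kernels):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def log_sigmoid(x):
    return -np.log1p(np.exp(-x))

a, b = 1.3, -0.4    # running and incoming log-sum-exp values
o1, o2 = 2.0, 5.0   # corresponding partial attention outputs

# Exact recombination of the two softmax normalizers.
lse_ref = np.log(np.exp(a) + np.exp(b))
out_ref = (np.exp(a) * o1 + np.exp(b) * o2) / (np.exp(a) + np.exp(b))

# The form used by update_out_and_lse above.
lse_new = a - log_sigmoid(a - b)
out_new = o1 - (o1 - o2) * sigmoid(b - a)

assert np.allclose([lse_ref, out_ref], [lse_new, out_new])
```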


def get_chunk_id(rank, cp_size):
@@ -130,14 +122,10 @@
comm_buffer = RingCommunicator(group, local_key, local_value)
local_q_seq_len = local_query.shape[1]

-    out, lse, k_cache, v_cache = None, None, dict(), dict()

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
-        local_query_second_chunk = paddle.slice(
-            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
-        )
+        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]

for step in range(cp_size):
block_k, block_v = comm_buffer.get_buffers()

@@ -159,16 +147,19 @@
not training,
"",
)
-            block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
-            out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+            paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)

+            if step == 0:
+                out, lse = block_out, block_lse
+            else:
+                out, lse = update_out_and_lse(out, lse, block_out, block_lse)
else:
# block_k and block_v is from rank (group.rank - step) % cp_size
if step == 0:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query, block_k, block_v, fixed_seed_offset, None, dropout, True, False, not training, ""
)
-                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
-                out, lse = update_out_and_lse(out, lse, block_out, block_lse)
+                paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
+                out, lse = block_out, block_lse
elif step > rank:
block_out, _, block_lse, _ = _C_ops.flash_attn(
local_query_second_chunk,
@@ -182,16 +173,14 @@
not training,
"",
)
-                block_lse = paddle.slice(block_lse, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
-                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
+                block_lse = block_lse[:, :, 0 : (local_q_seq_len // 2)]
+                paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse, True)
else:
-                block_k = paddle.slice(block_k, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
-                block_v = paddle.slice(block_v, axes=[1], starts=[0], ends=[local_q_seq_len // 2])
                block_out, _, block_lse, _ = _C_ops.flash_attn(
                    local_query,
-                    block_k,
-                    block_v,
+                    block_k[:, : local_q_seq_len // 2, :, :],
+                    block_v[:, : local_q_seq_len // 2, :, :],
fixed_seed_offset,
None,
dropout,
@@ -200,23 +189,19 @@
not training,
"",
)
-                block_lse = paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
+                paddle.unsqueeze_(paddle.transpose_(block_lse, [0, 2, 1]), axis=-1)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
-        k_cache[step] = block_k
-        v_cache[step] = block_v

+        # TODO(zhangyuqin1998): under the batch_isend_irecv async stream, wait() cannot be used and needs a fix; this hurts performance.
+        # if step != cp_size - 1:
+        #     comm_buffer.wait()
+        paddle.device.synchronize()

-    out = out.to(local_query.dtype)
-    lse = paddle.transpose_(paddle.squeeze_(lse, axis=-1), [0, 2, 1])
-    return out, lse, k_cache, v_cache
+    return paddle.cast(out, local_query.dtype), paddle.transpose_(paddle.squeeze(lse, axis=-1), [0, 2, 1])
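The three causal branches implement the balanced ("zigzag") schedule: the sequence is split into 2 * cp_size chunks, rank r holds chunks r and 2 * cp_size - 1 - r (see get_chunk_id), and the K/V arriving at step `step` originate at rank (rank - step) % cp_size. Note also that block_lse is now sliced on its sequence axis (the last axis of the [batch, heads, seq] layout) where the old code sliced axis 1, the head axis. A schematic of the schedule (illustrative bookkeeping only, not the kernel):

```python
# Which attention one rank computes at each ring step under the zigzag
# layout; chunk i may only attend chunks j <= i (causality).
cp_size, rank = 4, 1
q_chunks = (rank, 2 * cp_size - 1 - rank)
for step in range(cp_size):
    src = (rank - step) % cp_size            # origin rank of this step's K/V
    kv_chunks = (src, 2 * cp_size - 1 - src)
    if step == 0:
        print(step, "local causal attention:", q_chunks, "vs", kv_chunks)
    elif step > rank:
        # Both incoming chunks precede the second local query chunk but not
        # the first, so only local_query_second_chunk runs, on the full K/V.
        print(step, "second query chunk vs full", kv_chunks)
    else:
        # Only the earlier incoming chunk precedes the local queries, so all
        # queries run against the first half of the K/V buffer, unmasked.
        print(step, q_chunks, "vs first half of", kv_chunks)
```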


def balanced_ring_flash_attention_bwd_func(
group,
-    k_cache,
-    v_cache,
out_grad,
local_query,
local_key,
@@ -240,17 +225,10 @@
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
-        local_query_second_chunk = paddle.slice(
-            local_query, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
-        )
-        local_out_second_chunk = paddle.slice(
-            local_out, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
-        )
-        lse_second_chunk = paddle.slice(lse, axes=[2], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
-        out_grad_second_chunk = paddle.slice(out_grad, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len])
-        query_grad_buffer_second_chunk = paddle.slice(
-            query_grad_buffer, axes=[1], starts=[local_q_seq_len // 2], ends=[local_q_seq_len]
-        )
+        local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :]
+        local_out_second_chunk = local_out[:, local_q_seq_len // 2 :, :, :]
+        lse_second_chunk = lse[:, :, local_q_seq_len // 2 :]
+        out_grad_second_chunk = out_grad[:, local_q_seq_len // 2 :, :, :]

if attn_mask is not None:
attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
@@ -274,13 +252,13 @@
dropout,
False,
)
-            query_grad_buffer += block_q_grad
+            query_grad_buffer.add_(block_q_grad)
else:
if step == 0:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query, block_k, block_v, local_out, lse, fixed_seed_offset, None, out_grad, dropout, True
)
-                query_grad_buffer += block_q_grad
+                query_grad_buffer.add_(block_q_grad)
elif step > rank:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query_second_chunk,
@@ -294,12 +272,12 @@
dropout,
False,
)
-                query_grad_buffer_second_chunk += block_q_grad
+                query_grad_buffer[:, local_q_seq_len // 2 :, :, :].add_(block_q_grad)
else:
block_q_grad, block_k_grad, block_v_grad = flash_attn_bwd(
local_query,
-                    k_cache[step],
-                    v_cache[step],
+                    block_k[:, : local_q_seq_len // 2, :, :],
+                    block_v[:, : local_q_seq_len // 2, :, :],
local_out,
lse,
fixed_seed_offset,
@@ -308,9 +286,12 @@
dropout,
False,
)
-                query_grad_buffer += block_q_grad
+                query_grad_buffer.add_(block_q_grad)

+        # TODO(zhangyuqin1998): under the batch_isend_irecv async stream, wait() cannot be used and needs a fix; this hurts performance.
+        # if step != cp_size - 1:
+        #     kv_comm_buffer.wait()
+        # if step != 0:
+        #     grad_comm_buffer.wait()
+        paddle.device.synchronize()

grad_comm_buffer.add_to_buffers(block_k_grad, block_v_grad)
grad_comm_buffer.wait()
key_grad_buffer, value_grad_buffer = grad_comm_buffer.get_buffers()

-    dtype = local_query.dtype
-    return query_grad_buffer.to(dtype), key_grad_buffer.to(dtype), value_grad_buffer.to(dtype)
+    return query_grad_buffer, key_grad_buffer, value_grad_buffer
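The dK/dV accumulators travel the ring alongside the K/V blocks themselves: each step a rank adds the partial gradient it just computed into the arriving buffer via add_to_buffers, and the final grad_comm_buffer.wait() delivers each buffer back to the rank that owns those K/V. A toy bookkeeping check of that invariant (illustrative only, not the collective itself):

```python
# Every rank contributes one partial dK/dV per ring step, always for the
# block that originated at rank (r - step) % cp_size; after cp_size steps
# each block has collected a partial from every rank.
cp_size = 4
contributions = {origin: set() for origin in range(cp_size)}
for step in range(cp_size):
    for r in range(cp_size):
        origin = (r - step) % cp_size  # owner of the K/V block rank r holds now
        contributions[origin].add(r)
assert all(len(c) == cp_size for c in contributions.values())
```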


class RingFlashAttention(PyLayer):
@@ -344,10 +324,10 @@
if attn_mask is not None:
is_causal = False

-        out, lse, k_cache, v_cache = balanced_ring_flash_attention_fwd_func(
+        out, lse = balanced_ring_flash_attention_fwd_func(
group, query, key, value, fixed_seed_offset, attn_mask, dropout, is_causal, training
)
-        ctx.save_for_backward(query, key, value, out, lse, attn_mask, k_cache, v_cache)
+        ctx.save_for_backward(query, key, value, out, lse, attn_mask)
ctx.group = group
ctx.fixed_seed_offset = fixed_seed_offset
ctx.dropout = dropout
@@ -356,7 +336,7 @@

@staticmethod
def backward(ctx, out_grad):
-        query, key, value, out, lse, attn_mask, k_cache, v_cache = ctx.saved_tensor()
+        query, key, value, out, lse, attn_mask = ctx.saved_tensor()
group = ctx.group
fixed_seed_offset = ctx.fixed_seed_offset
dropout = ctx.dropout
Expand All @@ -366,19 +346,7 @@
fixed_seed_offset = paddle.to_tensor([0, 0], place=paddle.CPUPlace(), dtype=paddle.int64)

query_grad, key_grad, value_grad = balanced_ring_flash_attention_bwd_func(
-            group,
-            k_cache,
-            v_cache,
-            out_grad,
-            query,
-            key,
-            value,
-            out,
-            lse,
-            fixed_seed_offset,
-            attn_mask,
-            dropout,
-            is_causal,
+            group, out_grad, query, key, value, out, lse, fixed_seed_offset, attn_mask, dropout, is_causal
)
if attn_mask is not None and not attn_mask.stop_gradient:
return query_grad, key_grad, value_grad, None
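End to end, the layer is invoked like any other PyLayer. A hedged usage sketch follows; the argument order of RingFlashAttention.apply is an assumption inferred from the forward shown above and should be checked against the class before use:

```python
import paddle
import paddle.distributed as dist
from paddlenlp.transformers.ring_flash_attention import RingFlashAttention

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

# Each rank feeds its zigzag-split local shard: [batch, local_seq, heads, head_dim].
q = paddle.randn([2, 512, 16, 64], dtype="bfloat16")
k = paddle.randn([2, 512, 16, 64], dtype="bfloat16")
v = paddle.randn([2, 512, 16, 64], dtype="bfloat16")
for t in (q, k, v):
    t.stop_gradient = False

# Assumed order: (query, key, value, group, fixed_seed_offset, attn_mask, dropout, is_causal, training).
out = RingFlashAttention.apply(q, k, v, group, None, None, 0.0, True, True)
out.mean().backward()  # q.grad / k.grad / v.grad now hold the ring-attention gradients
```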
9 changes: 4 additions & 5 deletions tests/transformers/test_ring_flash_attention.py
@@ -83,17 +83,16 @@ def single_test(self, bsz, seq_len_per_device, head_num, head_dim, is_causal, us
)
ref_out = scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=attn_mask)

-        local_out.mean().backward()
-        ref_out.mean().backward()
+        local_out.backward()
+        ref_out.backward()

ref_local_query_grad = self.split_belanced_data(query.grad)
ref_local_key_grad = self.split_belanced_data(key.grad)
ref_local_value_grad = self.split_belanced_data(value.grad)

ref_local_out = self.split_belanced_data(ref_out)

-        rtol = 1e-04
-        atol = 5e-03
+        rtol = 1e-02
+        atol = 1e-02
np.testing.assert_allclose(
local_out.to("float32").numpy(), ref_local_out.to("float32").numpy(), rtol=rtol, atol=atol
)
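Note: the relaxed tolerances track the change above. Calling backward() on the full tensor rather than its mean scales every gradient, and therefore every absolute gradient error, by the element count, so the old atol of 5e-03 would likely no longer hold; both tolerances are loosened to 1e-02.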