
Commit 279458d
I don't get why dV fails
drisspg committed Aug 10, 2023
1 parent 7908525 commit 279458d
Showing 2 changed files with 24 additions and 2 deletions.
18 changes: 17 additions & 1 deletion test/test_flash.py
@@ -110,11 +110,27 @@ def test_flash_masked_block(dtype=torch.float16):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask
)

ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True) # type: ignore

tri_out.half()  # note: Tensor.half() is out-of-place, so this result is discarded
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# Check attn_bias equivalence
atol = 2e-2 * 6
torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
torch.testing.assert_close(ref_mask, mask.half(), atol=4e-2, rtol=0)

breakpoint()  # left in deliberately: the dV comparison below is the failing check
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)

if __name__ == "__main__":
pytest.main([__file__])
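
For reference, the inverse_causal bias used above is plausibly the mirror image of a causal mask: each query attends to itself and to later positions, while the past is masked out. This is a hedged sketch only — the real ref_mask is built earlier in the test, outside this hunk — and seq_len is an illustrative stand-in:

import torch

seq_len = 8  # illustrative; the test uses its own shapes
i = torch.arange(seq_len)[:, None]  # query index
j = torch.arange(seq_len)[None, :]  # key index
# As an additive attn_mask, zeros pass through and -inf removes a position.
inverse_causal_bias = torch.where(i <= j, 0.0, float("-inf"))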
8 changes: 7 additions & 1 deletion transformer_nuggets/flash/flash_attention.py
@@ -133,7 +133,7 @@ def _fwd_kernel(
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute scaling constant ---
row_max = tl.max(qk, 1)
-masked_out_rows= masked_row(row_max)
+masked_out_rows = masked_row(row_max)
m_i_new = tl.maximum(m_i, row_max)
# TODO FIX ME
# alpha = tl.math.exp2(m_i - m_i_new)
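
Aside: masked_row is defined elsewhere in the repo, so the following is only a sketch of its assumed behavior — flagging rows whose block-local max is still -inf, i.e. rows in which no score survived the mask and which would otherwise produce NaNs during renormalization:

import triton
import triton.language as tl

@triton.jit
def masked_row(rows_max):
    # Assumed behavior: a row is fully masked iff the masking step
    # replaced every one of its scores, leaving the row max at -inf.
    return rows_max == float("-inf")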
@@ -263,11 +263,17 @@ def _bwd_kernel(
qk = rel_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz%H, H)
elif BIAS_CHOICE == BiasMode.alibi:
qk = alibi_attention_triton(qk, offs_m_curr[:, None], (offs_n[None, :]), off_hz%H, H)
elif BIAS_CHOICE == BiasMode.inverse_causal:
# This should only be used for debugging
qk = inverse_causal_mask_triton(qk, offs_m[:, None], (start_n + offs_n[None, :]), off_hz%H, H)
# ~~~~~~~~~~~~~~~~~~~ This is the end of mask stuff ~~~~~~~~~~~~~~~~~~~
l_i = tl.load(l_ptrs + offs_m_curr)
row_max = tl.max(qk, 1)
masked_out_rows = masked_row(row_max)
# TODO fix me
# p = tl.math.exp2(qk - l_i[:, None])
p = tl.math.exp(qk - l_i[:, None])
p = tl.where(masked_out_rows[:, None], 0, p)
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
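
Why the new masked_out_rows guard matters in the backward pass — a hedged, plain-PyTorch illustration, not the Triton kernel itself: when every score in a row is -inf, the stored log-sum-exp for that row is also -inf, so reconstructing the probabilities evaluates exp(-inf - (-inf)), which is NaN and would flow straight into the dv accumulation just above:

import torch

qk = torch.full((2, 4), float("-inf"))  # row 1 is fully masked
qk[0] = torch.randn(4)                  # row 0 has surviving scores

lse = torch.logsumexp(qk, dim=-1)       # row 1: logsumexp of all -inf is -inf
p = torch.exp(qk - lse[:, None])        # row 1: exp(-inf - (-inf)) = NaN

masked_out_rows = lse == float("-inf")
p = torch.where(masked_out_rows[:, None], torch.zeros_like(p), p)
# Without the where(), the NaNs would propagate into dv = p.T @ do.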
