
Merge pull request #43 from hypnopump/fix_beta_gradient
[DRAFT] Beta gradient does not match
yzhangcs authored Aug 13, 2024
2 parents 10b841f + 141da4b commit d28b3e1
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/ops/test_delta.py
@@ -14,10 +14,11 @@
 @pytest.mark.parametrize("D", [128])
 @pytest.mark.parametrize("dtype", [torch.float])
 def test_beta_scalar_vector_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype):
+    torch.manual_seed(17)
     q = torch.randn(B, H, T, D, dtype=dtype)
     k = torch.nn.functional.normalize(torch.randn(B, H, T, D, dtype=dtype), p=2, dim=-1)
     v = torch.randn(B, H, T, D, dtype=dtype)
-    beta = torch.rand(B, H, T, dtype=dtype).sigmoid()
+    beta = torch.rand(B, H, T, D, dtype=dtype).sigmoid()
     q, k, v, beta = map(lambda x: x.cuda().requires_grad_(True), (q, k, v, beta))
     do = torch.rand_like(v)

@@ -31,11 +32,12 @@ def test_beta_scalar_vector_equivalence(B: int, H: int, T: int, D: int, dtype: torch.dtype):
     q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
     q.grad = k.grad = v.grad = beta.grad = None

-    assert o.allclose(o2, 0, 1e-3), f"Diff: {torch.abs(o - o2).max()}"
-    assert q_grad.allclose(q_grad2, 0, 1e-3), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
-    assert k_grad.allclose(k_grad2, 0, 1e-3), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
-    assert v_grad.allclose(v_grad2, 0, 1e-3), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
-    assert beta_grad.allclose(beta_grad2, 0, 1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"
+    assert o.allclose(o2, rtol=0, atol=2e-5), f"Diff: {torch.abs(o - o2).max()}"
+    assert q_grad.allclose(q_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(q_grad - q_grad2).max()}"
+    assert k_grad.allclose(k_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(k_grad - k_grad2).max()}"
+    assert v_grad.allclose(v_grad2, rtol=0, atol=2e-5), f"Diff: {torch.abs(v_grad - v_grad2).max()}"
+    # FIXME: this gradient does not match when beta is a vector; it matches when beta is a scalar.
+    assert beta_grad.allclose(beta_grad2, rtol=0, atol=1e-3), f"Diff: {torch.abs(beta_grad - beta_grad2).max()}"


 @pytest.mark.parametrize("B", [4])
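For context, here is a minimal sketch, not part of the diff and using only plain PyTorch, of the two gate conventions the test name refers to: a head-wise scalar beta of shape (B, H, T) versus a per-channel beta of shape (B, H, T, D), and of what allclose with rtol=0 checks once the relative tolerance is zeroed out. Only D=128 and B=4 appear in the visible parametrize lines; the other shapes, and the assumption that the vector form repeats the scalar across the head dimension, are illustrative.

    import torch

    # Illustrative shapes: D matches the test's parametrize([128]); B, H, T are assumed.
    B, H, T, D = 4, 4, 300, 128

    # Head-wise scalar gate: one beta per (batch, head, timestep).
    beta_scalar = torch.rand(B, H, T).sigmoid()

    # Per-channel gate: assumed here to simply repeat the scalar across the D channels,
    # so both forms encode the same update strength.
    beta_vector = beta_scalar.unsqueeze(-1).expand(B, H, T, D)
    assert torch.equal(beta_vector[..., 0], beta_scalar)

    # With rtol=0, allclose reduces to a pure absolute-error bound:
    # |a - b| <= atol element-wise, which is what the tightened asserts enforce.
    a = torch.randn(B, H, T, D)
    b = a + 1e-6 * torch.randn_like(a)
    assert torch.allclose(a, b, rtol=0, atol=2e-5), f"Diff: {torch.abs(a - b).max()}"

Setting rtol=0 makes the tolerance independent of the magnitude of the reference tensor, which is why the commit can state the bound directly as atol=2e-5 for the outputs and q/k/v gradients while keeping the looser atol=1e-3 for the still-mismatching beta gradient.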
