Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tcc0403 committed Dec 15, 2024
1 parent 0473e22 commit 2e6ded2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions test/transformers/test_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def _test_correctness_with_weight_once(
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
output.backward(gradient=torch.ones_like(output))
output2.backward(gradient=torch.ones_like(output))
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


Expand Down Expand Up @@ -397,11 +397,11 @@ def _test_correctness_with_weight_with_other_params_once(
softcap * torch.tanh(_input.to(torch.float32) / softcap), target
).to(dtype)
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)
assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
output.backward(gradient=torch.ones_like(output))
output2.backward(gradient=torch.ones_like(output))
assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_not_last_layer_once(
Expand Down Expand Up @@ -831,7 +831,7 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r
(3, 423, 3200),
],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
@pytest.mark.parametrize(
"ignore_index, lse_square_scale, label_smoothing, softcap",
[
Expand Down

0 comments on commit 2e6ded2

Please sign in to comment.