diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 2a980c69e..b0092d5ef 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -285,11 +285,12 @@ def cross_entropy_forward(
         num_warps=32 if not is_hip() else 16,
     )

-    loss = torch.sum(loss_1d)
-    if return_z_loss == _TRUE.value:
-        z_loss = torch.sum(z_loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
     else:
-        z_loss = None
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None

     return loss, z_loss, _input
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index c5e371654..5bb59d718 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -87,8 +87,8 @@ def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, r
     output2 = target_ce(_input2, target)

     assert torch.allclose(output, output2, atol=atol, rtol=rtol)

-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))

     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
@@ -118,8 +118,8 @@ def _test_correctness_with_ignore_index_once(
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)

-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))

     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
@@ -199,8 +199,8 @@ def _test_correctness_with_softcap_once(
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)

-    output.backward()
-    output2.backward()
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))

     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
@@ -325,8 +325,8 @@ def _test_correctness_not_last_layer_once(
     loss1 = output * 3
     loss2 = output2 * 3

-    loss1.backward()
-    loss2.backward()
+    loss1.backward(gradient=torch.ones_like(output))
+    loss2.backward(gradient=torch.ones_like(output))

     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
@@ -384,7 +384,7 @@ def _test_correctness_functional(
         (3, 423, 32000),  # weird shapes
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -432,7 +432,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
         (3, 423, 32000, -123),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -532,7 +532,7 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
         (3, 423, 32000, 30.0),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
@@ -700,7 +700,7 @@ def test_correctness_with_z_loss_with_other_params_once(
         (3, 423, 32000),
     ],
 )
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
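
For context on the test changes: with `reduction="none"` the loss is a per-token tensor rather than a scalar, so `.backward()` needs an explicit upstream gradient. A minimal sketch of the pattern the tests adopt, using plain `torch.nn.functional.cross_entropy` as a stand-in for the Liger kernel (shapes here are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 8, 32000  # illustrative shapes only
logits = torch.randn(B * T, V, requires_grad=True)
target = torch.randint(0, V, (B * T,))

# reduction="none" yields a per-token loss of shape (B*T,) instead of a scalar,
# so backward() requires a gradient tensor; ones_like() makes it equivalent to
# backpropagating the sum of the per-token losses.
loss = F.cross_entropy(logits, target, reduction="none")
loss.backward(gradient=torch.ones_like(loss))
```

Passing `gradient=torch.ones_like(output)` is also harmless for the `"sum"` and `"mean"` cases (a scalar output with a scalar ones gradient), which is why the tests can use one call pattern across all three reductions.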