Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTO loss #410

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

KTO loss #410

wants to merge 6 commits into from

Conversation

vulkomilev
Copy link

Summary

This is the kto loss implemented by references from other projects

Details

I am not sure about the correctness (because this is my first PR) of the final results so I expect a lot of comments

Testing Done

I have done the basic testing inspired from cpo

@@ -126,7 +126,7 @@ def test_correctness(
input1, weight1, target, bias1, alpha=alpha
)
loss2 = LigerFusedLinearCPOFunction.apply(
input2, weight2, target, bias2, ignore_index, beta, alpha, True
input2, weight2, target, bias2, ignore_index, beta, alpha, False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing the test case for an unrelated alignment algo?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry my bad.

Copy link
Collaborator

@pramodith pramodith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, I think this code needs to be refactored to make things a bit cleaner and easier to understand. Could you also write out the equations for KTO in the description to the PR so that its easier for a reviewer to understand?

from torch.nn import functional as F


class LigerFusedLinearKTOPreferenceBase(torch.autograd.Function):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am getting this error
E RuntimeError: CUDA error: device-side assert triggered E CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. E For debugging consider passing CUDA_LAUNCH_BLOCKING=1 E Compile with TORCH_USE_CUDA_DSA` to enable device-side assertions.

src/liger_kernel/chunked_loss/fused_linear_preference.py:210: RuntimeError
---------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------
NoneType: None
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [6,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [7,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [12,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [83,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [32,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [43,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [54,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [59,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [62,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
=============================================================================================================== short test summary info ===============================================================================================================
FAILED test/chunked_loss/test_kto_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] - RuntimeError: CUDA error: device-side assert triggered
================================================================================================================== 1 failed in 1.86s ============================================`

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do the equations and the formatting. Also I need two arguments 'reference_chosen_logps' and 'reference_rejected_logps' to my custom loss function.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused do you still need this file? The base classes abstract preference_loss_fn does accept those two arguments, you can set beta=0 if it's not needed.

In case you need a completely new function signature, my advice would be to add a new overloaded function in the existing base class.

@vulkomilev
Copy link
Author

Okay code formatted and comment about the source of the loss added

Copy link
Collaborator

@pramodith pramodith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vulkomilev can you please make sure that all unsued code is deleted and can you also confirm if

make checkstyle
make test works?

It'd be great if you can add the equations of KTO in the PRs description similar to #386

@vulkomilev
Copy link
Author

make checkstyle and make test works now.The commented code was removed and I have added the formula in kto_loss.py but I am not sure about the formmating

@pramodith
Copy link
Collaborator

@ByronHsu @shivam15s could either of you please take over reviewing this PR, have to switch my focus to other stuff.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants