diff --git a/.gitignore b/.gitignore index cf4226001..6fe9e4f20 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ site/ .venv/ venv/ .ipynb_checkpoints/ +.vscode/ # Misc .DS_Store diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 000000000..64f16e10d --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,30 @@ +## Benchmarking Liger Kernels + +Follow these steps to benchmark and visualize kernel performance: + +1. Create a benchmark script + - Add your script under `benchmark/scripts/` + - Name it according to the kernel (e.g., `benchmark_.py`) + +2. Run the benchmark + - Results will be saved to `benchmark/data/all_benchmark_data.csv` + + Example: Benchmarking KTO Loss + ```bash + cd benchmark + python scripts/benchmark_kto_loss.py + ``` + +3. Visualize results + - Use the visualization script with appropriate parameters + + Example: Visualizing KTO Loss benchmark results + ```bash + python benchmarks_visualizer.py \ + --kernel-name kto_loss \ + --metric-name memory \ + --kernel-operation-mode full + ``` + +4. View results + - Generated plots will be saved in `benchmark/visualizations/` \ No newline at end of file diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 4e966cab2..df342e0a6 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,8.2532958984375,8.235372543334961,8.274937629699707,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,16.888959884643555,16.879615783691406,16.898893356323242,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,32.13854217529297,32.12795639038086,32.149131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,64.81161499023438,64.81161499023438,64.81161499023438,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1 +kto_loss,liger,forward,speed,ms,B,Batch Size 
(B),32,128.68646240234375,128.68646240234375,128.68646240234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.146656036376953,7.143622398376465,7.152345657348633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,12.538240432739258,12.521356582641602,12.540371894836426,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,26.29542350769043,25.303590774536133,26.88591957092285,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,49.26508712768555,49.26508712768555,49.26508712768555,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,98.9525146484375,98.9525146484375,98.9525146484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1 +kto_loss,liger,full,speed,ms,B,Batch Size (B),2,9.005151748657227,8.97766399383545,9.046483039855957,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1 +kto_loss,liger,full,speed,ms,B,Batch Size (B),4,19.108863830566406,19.09713363647461,19.185260772705078,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1 +kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.80137634277344,32.775360107421875,32.827388763427734,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1 +kto_loss,liger,full,speed,ms,B,Batch Size (B),16,65.46678161621094,65.46678161621094,65.46678161621094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1 +kto_loss,liger,full,speed,ms,B,Batch Size (B),32,129.91734313964844,129.91734313964844,129.91734313964844,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,16.091487884521484,14.86076831817627,16.23084831237793,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, 
""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,28.04204750061035,28.03957176208496,28.055641174316406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,54.70073699951172,54.70073699951172,54.70073699951172,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,108.09929656982422,108.09929656982422,108.09929656982422,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,215.1945343017578,215.1945343017578,215.1945343017578,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1 +kto_loss,liger,full,memory,MB,B,Batch Size (B),2,3037.75390625,3037.75390625,3037.75390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1 +kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3800.0126953125,3800.0126953125,3800.0126953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1 +kto_loss,liger,full,memory,MB,B,Batch Size (B),8,4565.28076171875,4565.28076171875,4565.28076171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1 +kto_loss,liger,full,memory,MB,B,Batch Size (B),16,4589.31787109375,4589.31787109375,4589.31787109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1 +kto_loss,liger,full,memory,MB,B,Batch Size (B),32,4637.39208984375,4637.39208984375,4637.39208984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4793.7626953125,4793.7626953125,4793.7626953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6551.2978515625,6551.2978515625,6551.2978515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,10063.3681640625,10063.3681640625,10063.3681640625,"{""T"": 512, ""H"": 1024, ""V"": 
128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,17093.5078125,17093.5078125,17093.5078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,31153.7890625,31153.7890625,31153.7890625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1 diff --git a/benchmark/scripts/benchmark_kto_loss.py b/benchmark/scripts/benchmark_kto_loss.py new file mode 100644 index 000000000..aaffc4177 --- /dev/null +++ b/benchmark/scripts/benchmark_kto_loss.py @@ -0,0 +1,264 @@ +import os +import sys + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.utils import infer_device + +device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchKTOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + from test.chunked_loss.test_kto_loss import HFKTOLoss + + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) + self.kto_loss = HFKTOLoss( + ignore_index=ignore_index, beta=beta, use_ref_model=True + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y): + return self.kto_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + )[0] + + +class LigerKTOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) + self.kto_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, beta=beta, use_ref_model=True + ) + + def forward(self, x, ref_x, y): + return self.kto_loss( + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, + )[0] + + +def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_kto_loss = TorchKTOLoss( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + liger_kto_loss = LigerKTOLoss( + H=H, + V=V, + dtype=dtype, + bias=bias, 
+ ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + # Add ignore_index tokens to simulate padding + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss(_input, ref_input, target) + elif provider == "huggingface": + return torch_kto_loss(_input, ref_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_kto_loss = TorchKTOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + liger_kto_loss = LigerKTOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss(_input, ref_input, target) + elif provider == "huggingface": + return torch_kto_loss(_input, ref_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "kto_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_kto_loss, + 
kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + + run_benchmarks( + bench_test_fn=bench_memory_kto_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/README.md b/src/liger_kernel/chunked_loss/README.md index 15ab24543..1dd7037f2 100644 --- a/src/liger_kernel/chunked_loss/README.md +++ b/src/liger_kernel/chunked_loss/README.md @@ -1,6 +1,6 @@ # Liger FlexChunkLoss: Alignment and Distillation loss -Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases. +Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases. ### User interface diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 238bdded9..87f3887b5 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,4 +1,5 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index 5a51d3f72..a10398400 100644 --- a/src/liger_kernel/chunked_loss/functional.py +++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,5 +1,6 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction @@ -7,3 +8,4 @@ liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply +liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index fff0791ec..cb82fbb5f 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -28,6 +28,8 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, + unpaired=False, + preference_labels=None, use_ref_model=False, ref_input=None, ref_weight=None, @@ -59,6 +61,10 @@ def forward( compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. + unpaired (bool): Whether the inputs are unpaired (chosen and rejected). 
+ Some loss functions that don't use paired preference, like KTO, can set this to True. + preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples. + Shape: (batch_size,). Required if unpaired is True. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Other possible arguments that a loss function might need @@ -181,21 +187,59 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if compiled: fused_fwd_bwd = torch.compile(fused_fwd_bwd) - len_chosen = target.shape[0] // 2 - chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) - _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) - _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) - _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) - _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) + if not unpaired: + len_chosen = target.shape[0] // 2 + chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) + _chosen_input_chunks = torch.chunk( + _input[:len_chosen], chunks=chunks, dim=0 + ) + _chosen_target_chunks = torch.chunk( + target[:len_chosen], chunks=chunks, dim=0 + ) + _rejected_input_chunks = torch.chunk( + _input[len_chosen:], chunks=chunks, dim=0 + ) + _rejected_target_chunks = torch.chunk( + target[len_chosen:], chunks=chunks, dim=0 + ) - if use_ref_model: - _ref_chosen_input_chunks = torch.chunk( - ref_input[:len_chosen], chunks=chunks, dim=0 + if use_ref_model: + _ref_chosen_input_chunks = torch.chunk( + ref_input[:len_chosen], chunks=chunks, dim=0 + ) + _ref_rejected_input_chunks = torch.chunk( + ref_input[len_chosen:], chunks=chunks, dim=0 + ) + else: + # When not paired, use labels to separate chosen and rejected + assert ( + preference_labels is not None + ), "preference_labels must be provided when unpaired=True" + chosen_mask = preference_labels == 1 + rejected_mask = ~chosen_mask + + chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) + _chosen_input_chunks = torch.chunk( + _input[chosen_mask], chunks=chunks, dim=0 + ) + _chosen_target_chunks = torch.chunk( + target[chosen_mask], chunks=chunks, dim=0 + ) + _rejected_input_chunks = torch.chunk( + _input[rejected_mask], chunks=chunks, dim=0 ) - _ref_rejected_input_chunks = torch.chunk( - ref_input[len_chosen:], chunks=chunks, dim=0 + _rejected_target_chunks = torch.chunk( + target[rejected_mask], chunks=chunks, dim=0 ) + if use_ref_model: + _ref_chosen_input_chunks = torch.chunk( + ref_input[chosen_mask], chunks=chunks, dim=0 + ) + _ref_rejected_input_chunks = torch.chunk( + ref_input[rejected_mask], chunks=chunks, dim=0 + ) + for ( chosen_input_chunk, rejected_input_chunk, @@ -283,6 +327,7 @@ def chunk_forward( ignore_index=-100, compute_nll_loss=True, ): + # Data is already properly stacked (chosen then rejected) by the forward method len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() if bias is not None: diff --git a/src/liger_kernel/chunked_loss/kto_loss.py b/src/liger_kernel/chunked_loss/kto_loss.py new file mode 100644 index 000000000..181c6d11c --- /dev/null +++ b/src/liger_kernel/chunked_loss/kto_loss.py @@ -0,0 +1,219 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) + + +class 
LigerFusedLinearKTOFunction(LigerFusedLinearPreferenceBase): + + @staticmethod + def preference_loss_fn( + chosen_logps, + rejected_logps, + full_target, + ref_chosen_logps=None, + ref_rejected_logps=None, + beta=0.1, + policy_KL_logps=None, + ref_KL_logps=None, + ): + """ + Implements the Kahneman-Tversky Optimization (KTO) loss function. + Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization" + https://arxiv.org/abs/2402.01306 + + KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory) + from behavioral economics, which models how humans make decisions under uncertainty. + The loss function is asymmetric, treating gains and losses differently, similar to + human decision-making patterns. + + Formula: + When y is chosen: + L_KTO = 1 - σ(β * (log[π(y|x)/π₀(y|x)] - KL(π||π₀))) + When y is rejected: + L_KTO = 1 - σ(β * (KL(π||π₀) - log[π(y|x)/π₀(y|x)])) + + Where: + - σ: Sigmoid function + - β: Temperature parameter controlling the strength of the preference signal + - π(y|x): Likelihood of response y under the policy (current) model + - π₀(y|x): Likelihood of response y under the reference model + - KL(π||π₀): KL divergence between the policy and the reference model, estimated over a batch of completions (see §4.1 of the paper) + + The loss encourages the model to: + 1. Assign higher probability to chosen responses + 2. Assign lower probability to rejected responses + 3. Maintain reasonable distance from the reference model + + Args: + chosen_logps: Log probabilities of chosen tokens (batch_size,) + rejected_logps: Log probabilities of rejected tokens (batch_size,) + full_target: Non-chunked full target tensor + ref_chosen_logps: Reference log probs of chosen tokens (batch_size,) + ref_rejected_logps: Reference log probs of rejected tokens (batch_size,) + beta: Temperature parameter controlling the strength of the preference signal + policy_KL_logps: Log probabilities of the policy model on the completions used to estimate the KL term (computed in a separate full forward pass). Shape: (batch_size,) + ref_KL_logps: Log probabilities of the reference model on the same completions, so that policy_KL_logps - ref_KL_logps estimates KL(π||π₀).
Shape: (batch_size,) + + Returns: + Tuple of (loss, chosen_rewards, rejected_rewards): + - loss: The KTO loss value + - chosen_rewards: Reward signals for chosen responses (detached) + - rejected_rewards: Reward signals for rejected responses (detached) + """ + if ref_chosen_logps is None: + ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) + if ref_rejected_logps is None: + ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) + + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + + if policy_KL_logps is None: + policy_KL_logps = torch.tensor(0.0, device=chosen_logps.device) + if ref_KL_logps is None: + ref_KL_logps = torch.tensor(0.0, device=chosen_logps.device) + + kl = policy_KL_logps - ref_KL_logps + + losses = torch.cat( + ( + 1 - F.sigmoid(beta * (chosen_logratios - kl)), + 1 - F.sigmoid(beta * (kl - rejected_logratios)), + ), + 0, + ) + + chosen_rewards = beta * chosen_logratios.detach() + rejected_rewards = beta * rejected_logratios.detach() + + return ( + # We don't divide by 2 because KTO Loss doesn't need pair-wise examples + losses.sum() / (full_target.shape[0]), + chosen_rewards, + rejected_rewards, + ) + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + preference_labels=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + ignore_index=-100, + beta=0.1, + compute_nll_loss=True, + compiled=True, + use_ref_model=True, + policy_KL_logps=None, + ref_KL_logps=None, + ): + return LigerFusedLinearPreferenceBase.forward( + ctx=ctx, + _input=_input, + weight=weight, + target=target, + bias=bias, + loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn, + ignore_index=ignore_index, + beta=beta, + compute_nll_loss=compute_nll_loss, + compiled=compiled, + use_ref_model=use_ref_model, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, + unpaired=True, # KTO loss functions use unpaired preference + preference_labels=preference_labels, + policy_KL_logps=policy_KL_logps, + ref_KL_logps=ref_KL_logps, + ) + + @staticmethod + def backward(ctx, *grad_output): + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + return ( + *grads, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class LigerFusedLinearKTOLoss(torch.nn.Module): + """ + Fused linear layer with Kahneman-Tversky Optimization (KTO) loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + compute_nll_loss: bool = True, + compiled: bool = True, + use_ref_model: bool = False, + policy_KL_logps: torch.FloatTensor = None, + ref_KL_logps: torch.FloatTensor = None, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss calculation + beta (float): Temperature parameter for the KTO loss + compute_nll_loss (bool): Whether to compute the NLL loss alongside KTO + compiled (bool): Whether to use compiled operations + use_ref_model (bool): Whether to use a reference model for the KTO loss.
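+ + Example (an illustrative sketch only; shapes and values are arbitrary, and CPU/float32 is used for brevity): + >>> import torch + >>> H, V, B, T = 16, 32, 4, 8 + >>> lin = torch.nn.Linear(H, V) + >>> ref_lin = torch.nn.Linear(H, V) + >>> loss_fn = LigerFusedLinearKTOLoss(beta=0.1, use_ref_model=True) + >>> x = torch.randn(B, T, H, requires_grad=True) + >>> ref_x = torch.randn(B, T, H) + >>> y = torch.randint(0, V, (B, T)) + >>> labels = torch.tensor([True, False, True, False])  # chosen vs. rejected + >>> loss, aux = loss_fn(x, lin.weight, y, bias=lin.bias, preference_labels=labels, ref_input=ref_x, ref_weight=ref_lin.weight, ref_bias=ref_lin.bias) + >>> loss.backward()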
+ """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compute_nll_loss = compute_nll_loss + self.compiled = compiled + self.use_ref_model = use_ref_model + self.policy_KL_logps = policy_KL_logps + self.ref_KL_logps = ref_KL_logps + + def forward( + self, + _input, + lin_weight, + target, + bias=None, + preference_labels=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + return LigerFusedLinearKTOFunction.apply( + _input, + lin_weight, + target, + bias, + preference_labels, + ref_input, + ref_weight, + ref_bias, + self.ignore_index, + self.beta, + self.compute_nll_loss, + self.compiled, + self.use_ref_model, + self.policy_KL_logps, + self.ref_KL_logps, + ) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 0ac8faeb8..748f253af 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -17,9 +17,9 @@ class HFDPOLoss(HFAlignmentLoss): """ - Implementation of the Odds Ratio Preference Optimization (ORPO) loss, + Implementation of the Direct Preference Optimization (DPO) loss, adapted from Hugging Face's implementation. - Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py """ def __init__( diff --git a/test/chunked_loss/test_kto_loss.py b/test/chunked_loss/test_kto_loss.py new file mode 100644 index 000000000..75fcb3b2a --- /dev/null +++ b/test/chunked_loss/test_kto_loss.py @@ -0,0 +1,390 @@ +from test.utils import HFAlignmentLoss, assert_verbose_allclose, set_seed + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_kto +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction +from liger_kernel.utils import infer_device + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFKTOLoss(HFAlignmentLoss): + """ + Implementation of the Kahneman-Tversky Optimization (KTO) loss, + adapted from Hugging Face's implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + policy_KL_logps: torch.FloatTensor = None, + ref_KL_logps: torch.FloatTensor = None, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + policy_KL_logps=policy_KL_logps, + ref_KL_logps=ref_KL_logps, + unpaired=True, + ) + # KL logps need to be passed into the Loss class since it requires a full model forward pass + # See paper https://arxiv.org/abs/2402.01306 (4.1. Derivation) + self.policy_KL_logps = policy_KL_logps + self.ref_KL_logps = ref_KL_logps + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ): + """Compute KTO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + ref_chosen_logps: Log probabilities of the reference model for the chosen responses. 
Shape: (batch_size,) + ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + Returns: + The losses tensor contains the KTO loss for each example in the batch. + """ + if self.policy_KL_logps is None: + self.policy_KL_logps = torch.zeros(1).to(device) + + if self.ref_KL_logps is None: + self.ref_KL_logps = torch.zeros(1).to(device) + + kl = (self.policy_KL_logps - self.ref_KL_logps).mean().clamp(min=0).detach() + + chosen_logratios = policy_chosen_logps - ref_chosen_logps + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + chosen_rewards = self.beta * chosen_logratios.detach() + + rejected_logratios = policy_rejected_logps - ref_rejected_logps + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + rejected_rewards = self.beta * rejected_logratios.detach() + + losses = torch.cat((chosen_losses, rejected_losses), 0) + return losses, chosen_rewards, rejected_rewards + + +class TorchLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + policy_KL_logps: torch.FloatTensor = None, + ref_KL_logps: torch.FloatTensor = None, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) + self.KTO_loss = HFKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + policy_KL_logps=policy_KL_logps, + ref_KL_logps=ref_KL_logps, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y, preference_labels): + return self.KTO_loss( + weight=self.lin.weight, + _input=x, + target=y, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + preference_labels=preference_labels, + ) + + +class LigerLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + policy_KL_logps: torch.FloatTensor = None, + ref_KL_logps: torch.FloatTensor = None, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) + self.KTO_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + policy_KL_logps=policy_KL_logps, + ref_KL_logps=ref_KL_logps, + ) + + def forward(self, x, ref_x, y, preference_labels): + return self.KTO_loss( + _input=x, + lin_weight=self.lin.weight, + target=y, + preference_labels=preference_labels, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (8, 47, 31, 123), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 2e-2, 5e-1), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta +): + B = 2 * B + # Create labels tensor with scattered True values + preference_labels = torch.zeros(B, dtype=torch.bool, 
device=device) + num_chosen = B // 2 # Keep same number of chosen examples + generator = torch.Generator(device=device).manual_seed(42) + chosen_indices = torch.randperm(B, generator=generator, device=device)[:num_chosen] + preference_labels[chosen_indices] = True + + torch_lm_head_KTO = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_KTO = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_KTO.lin.weight.data = liger_lm_head_KTO.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_KTO.ref_lin.weight.data = liger_lm_head_KTO.ref_lin.weight.data = ( + torch.randn(V, H, device=device, dtype=dtype) + ) + + if bias: + torch_lm_head_KTO.lin.bias.data = liger_lm_head_KTO.lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + if ref_bias: + torch_lm_head_KTO.ref_lin.bias.data = liger_lm_head_KTO.ref_lin.bias.data = ( + torch.randn(V, device=device, dtype=dtype) + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = ( + torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + ) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1, aggregated_aux_outputs1 = torch_lm_head_KTO( + x=input1, ref_x=ref_input, y=target, preference_labels=preference_labels + ) + loss2, aggregated_aux_outputs2 = liger_lm_head_KTO( + x=input2, ref_x=ref_input, y=target, preference_labels=preference_labels + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + assert len(aggregated_aux_outputs1) == len(aggregated_aux_outputs2) + + for i in range(len(aggregated_aux_outputs1)): + if i >= 5: # skip checking chosen_rewards and rejected_rewards + continue + assert_verbose_allclose( + aggregated_aux_outputs1[i], + aggregated_aux_outputs2[i], + atol=atol, + rtol=rtol, + ) + + loss1.backward() + loss2.backward() + + # Passed + assert_verbose_allclose(input1, input2, atol=atol, rtol=rtol) + assert_verbose_allclose(torch_lm_head_KTO.lin.weight, liger_lm_head_KTO.lin.weight, atol=atol, rtol=rtol) + + if bias: + assert_verbose_allclose(torch_lm_head_KTO.lin.bias, liger_lm_head_KTO.lin.bias, atol=atol, rtol=rtol) + + # Failed + # assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_KTO.lin.weight.grad, + liger_lm_head_KTO.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_KTO.lin.bias.grad, + liger_lm_head_KTO.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): + B = 2 * B + + # Create 
labels tensor with scattered True values + preference_labels = torch.zeros(B, dtype=torch.bool, device=device) + num_chosen = B // 2 # Keep same number of chosen examples + generator = torch.Generator(device=device).manual_seed(42) + chosen_indices = torch.randperm(B, generator=generator, device=device)[:num_chosen] + preference_labels[chosen_indices] = True + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = ( + torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + ) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + + loss1, aggregated_aux_outputs1 = LigerFusedLinearKTOFunction.apply( + input1, + weight1, + target, + bias1, + preference_labels, + ref_input, + ref_weight1, + ref_bias1, + ) + loss2, aggregated_aux_outputs2 = liger_fused_linear_kto( + input2, + weight2, + target, + bias2, + preference_labels, + ref_input, + ref_weight2, + ref_bias2, + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 3d3799ad0..2ab9a37be 100644 --- a/test/utils.py +++ b/test/utils.py @@ -350,11 +350,14 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + unpaired: bool = False, + **kwargs, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.unpaired = unpaired @abstractmethod def alignment_loss(self): @@ -402,6 +405,7 @@ def get_ref_logps( target: torch.LongTensor, ref_bias: torch.FloatTensor, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, ): """Compute the log probabilities of the given labels under the given reference model.""" @@ -411,10 +415,16 @@ ref_all_logps = self.get_batch_logps( ref_logits, target, average_log_prob=average_log_prob ) - return ( - ref_all_logps[: _input.shape[0] // 2], - ref_all_logps[_input.shape[0] // 2 :], - ) + + if self.unpaired and preference_labels is not None: + # Split based on preference labels + return ref_all_logps[preference_labels], ref_all_logps[~preference_labels] + else: + # Original paired behavior - split in half + return ( + ref_all_logps[: _input.shape[0] // 2], + ref_all_logps[_input.shape[0] // 2 :], + ) def concatenated_forward( 
self, @@ -423,6 +433,7 @@ def concatenated_forward( target: torch.LongTensor, bias: torch.FloatTensor = None, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, ) -> Tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor ]: @@ -458,11 +469,19 @@ def cross_entropy_loss(logits, labels): average_log_prob=average_log_prob, ) - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + if self.unpaired and preference_labels is not None: + # Split based on labels tensor + chosen_logps = all_logps[preference_labels] + rejected_logps = all_logps[~preference_labels] + chosen_logits = all_logits[preference_labels] + rejected_logits = all_logits[~preference_labels] + else: + # Original paired behavior - split in half + len_chosen = _input.shape[0] // 2 + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, @@ -482,11 +501,12 @@ def get_batch_loss_metrics( ref_weight: torch.FloatTensor = None, ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, ): """Compute the loss metrics for the given batch of inputs for train or test.""" forward_output = self.concatenated_forward( - _input, weight, target, bias, average_log_prob + _input, weight, target, bias, average_log_prob, preference_labels ) ( policy_chosen_logps, @@ -499,7 +519,12 @@ def get_batch_loss_metrics( loss_kwargs = {} if self.use_ref_model: ref_chosen_logps, ref_rejected_logps = self.get_ref_logps( - ref_input, ref_weight, target, ref_bias, average_log_prob + ref_input, + ref_weight, + target, + ref_bias, + average_log_prob, + preference_labels, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
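
For reference, when run over the whole batch at once, the chunked `preference_loss_fn` above computes the same preference term as this eager-mode sketch (the function name is local to this snippet, and the scalar `kl` argument stands in for the `policy_KL_logps - ref_KL_logps` term used in the kernel):

```python
import torch
import torch.nn.functional as F

def kto_loss_reference(chosen_logps, rejected_logps,
                       ref_chosen_logps, ref_rejected_logps,
                       beta=0.1, kl=0.0):
    # Log-ratios log[pi(y|x) / pi_0(y|x)] for chosen and rejected responses.
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    # Asymmetric, prospect-theory-style terms: chosen and rejected responses
    # are scored on opposite sides of the estimated KL reference point.
    chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios))
    losses = torch.cat((chosen_losses, rejected_losses), 0)
    # Normalize by the full unpaired batch size rather than by pairs,
    # matching `losses.sum() / full_target.shape[0]` in the kernel.
    return losses.sum() / (chosen_logps.shape[0] + rejected_logps.shape[0])
```

Normalizing by the total number of examples rather than by pairs is what lets KTO consume batches where the chosen and rejected counts differ, which is why the fused path sets `unpaired=True` and splits the batch with `preference_labels` instead of assuming a half/half layout.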