From ebd53035306685aaad8b7df1582ce77b66a23be1 Mon Sep 17 00:00:00 2001 From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:46:47 +0000 Subject: [PATCH] Add Chunked SimPO Loss (#386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR adds the Simple Preference Optimization Loss function. The only difference between SimPO and CPO is a margin term `gamma` which specifies that the preferred response should be atleast gamma logit points better than the losing response. $$SimPOLoss = -\log(\sigma(\beta\log(\pi_\theta(y_c|x)) - \beta\log(\pi_\theta(y_r|x)) - \gamma))$$ Note that SimPO explicitly specifies that $$\pi_\theta(y|x)$$ needs to be normalized by length, unlike DPO. This corresponds to Eq 6 in the [paper](https://arxiv.org/pdf/2405.14734). ## Testing Done GPU A100-80G-SXM ![Screenshot 2024-11-15 at 2 38 23 PM](https://github.com/user-attachments/assets/ac126f94-ebd8-4457-a4a2-53832699af4c) ![Screenshot 2024-11-15 at 2 38 37 PM](https://github.com/user-attachments/assets/e539e9cd-f66a-42dd-8b43-3ae44dcd42a0) - Hardware Type: - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [X] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu --- benchmark/data/all_benchmark_data.csv | 24 +++ benchmark/scripts/benchmark_simpo_loss.py | 191 ++++++++++++++++++ .../chunked_loss/fused_linear_preference.py | 8 +- src/liger_kernel/chunked_loss/orpo_loss.py | 2 +- src/liger_kernel/chunked_loss/simpo_loss.py | 64 ++++++ test/chunked_loss/test_cpo_loss.py | 12 +- test/chunked_loss/test_simpo_loss.py | 78 +++++++ test/utils.py | 10 +- 8 files changed, 381 insertions(+), 8 deletions(-) create mode 100644 benchmark/scripts/benchmark_simpo_loss.py create mode 100644 src/liger_kernel/chunked_loss/simpo_loss.py create mode 100644 test/chunked_loss/test_simpo_loss.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 6e5fd4ce0..ed25905cd 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -691,3 +691,27 @@ fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.31445 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 
+fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": 
""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py new file mode 100644 index 000000000..457f6f2e8 --- /dev/null +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -0,0 +1,191 @@ +import os +import sys + +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchLMHeadSimPO(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based cross entropy loss. 
+ + :param H: hidden size + :param V: vocab size + :param ignore_index: index to ignore + :param reduction: reduction method + """ + + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + from test.chunked_loss.test_cpo_loss import HFCPOLoss + + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics + + def forward(self, x, y): + return self.simpo_loss(x, self.lin.weight, y) + + +class LigerLMHeadSimPO(torch.nn.Module): + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype + ) + self.simpo_loss = LigerFusedLinearSimPOFunction.apply + + def forward(self, x, y): + return self.simpo_loss(x, self.lin.weight, y) + + +############################################################################# +# Test the memory consumption of the linear fused cross entropy loss +############################################################################# + + +def bench_memory_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_simpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_simpo(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear cross entropy loss +# ############################################################################# + + +def bench_speed_fused_linear_simpo_loss( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + device = "cuda" + + torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) + + _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_simpo(_input, target) + elif provider == "huggingface": + return torch_lm_head_simpo(_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, 
+ ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_simpo_loss", + "x_name": "B", + "x_label": "B", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 1024, + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_simpo_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_simpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c43caf839..73981dff4 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -32,6 +32,7 @@ def forward( alpha=1.0, beta=0.1, compiled=True, + **loss_kwargs, ): """ Base class for fused linear layer with preference loss. @@ -49,6 +50,7 @@ def forward( alpha (float): Weight for the NLL loss. beta (float): Weight for the odds ratio loss. compiled (bool): Whether to use torch compile for chunk accumulation. + loss_kwargs (dict): Other possible arguments that a loss function might need """ # TODO: Tune CHUNK_SIZE to fully utilize the GPU CHUNK_SIZE = chunk_size @@ -68,6 +70,7 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, full_target=target, + **loss_kwargs, ) def accumulate_chunk(input_chunk, target_chunk): @@ -94,6 +97,9 @@ def accumulate_chunk(input_chunk, target_chunk): loss_acc.add_(chunk_loss) return chunk_grad_input + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + len_chosen = target.shape[0] // 2 _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) @@ -116,8 +122,6 @@ def accumulate_chunk(input_chunk, target_chunk): [chosen_target_chunk, rejected_target_chunk], dim=0 ) - if compiled: - accumulate_chunk = torch.compile(accumulate_chunk) grad_input = accumulate_chunk(input_chunk, target_chunk) grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]]) diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index 0ff146d5d..a921f3f11 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -34,7 +34,7 @@ def forward( ignore_index=-100, beta=0.1, compute_nll_loss=True, - compiled=True, + compiled=False, ): """ Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss. 
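
The `fused_linear_preference.py` change above does two things: `LigerFusedLinearPreferenceBase.forward` now accepts `**loss_kwargs` and forwards them into the partially-applied `preference_loss_fn`, which is how a loss-specific argument such as SimPO's `gamma` reaches the loss without touching the shared forward signature, and `torch.compile(accumulate_chunk)` moves out of the per-chunk loop so the accumulator is compiled once rather than re-wrapped for every chunk. A minimal sketch of the kwargs-forwarding pattern, with illustrative (non-library) names:

```python
# Illustrative sketch only -- the names below are hypothetical, not the Liger API.
from functools import partial

import torch
import torch.nn.functional as F


def simpo_style_loss(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
    # Mirrors the shape of LigerFusedLinearSimPOFunction.preference_loss_fn below.
    logits = beta * (chosen_logps - rejected_logps) - gamma
    return F.logsigmoid(logits).mean()


def base_forward(loss_fn, chosen_logps, rejected_logps, beta=0.1, **loss_kwargs):
    # The base binds the generic args (beta, ...) together with whatever the
    # subclass passed through **loss_kwargs (here: gamma), so adding a new loss
    # does not require changing the shared forward signature.
    compute_loss = partial(loss_fn, beta=beta, **loss_kwargs)
    return compute_loss(chosen_logps, rejected_logps)


chosen = torch.tensor([-0.8, -1.1])
rejected = torch.tensor([-1.5, -1.4])
print(base_forward(simpo_style_loss, chosen, rejected, beta=0.1, gamma=0.5))
```
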
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py new file mode 100644 index 000000000..eff581406 --- /dev/null +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -0,0 +1,64 @@ +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) + + +class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): + + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5): + """ + Compute odds-ratio loss. + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Weight for the odds ratio loss. + gamma (float): The simpo gamma, margin term. + """ + logits = beta * (chosen_logps - rejected_logps) - gamma + loss = F.logsigmoid(logits).mean() + return loss + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + alpha=1.0, + compute_nll_loss=False, + compiled=True, + gamma=0.5, + ): + """ + Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734 + Handles both the forward and backward pass of the final linear layer with SimPO loss. + Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. + """ + + return LigerFusedLinearPreferenceBase.forward( + ctx, + _input, + weight, + target, + bias, + loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn, + compute_nll_loss=compute_nll_loss, + ignore_index=ignore_index, + alpha=alpha, + beta=beta, + compiled=compiled, + gamma=gamma, + ) + + @staticmethod + def backward(ctx, grad_output): + # Get gradients for _input, weight, bias, and target from the base class + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + # Return these gradients, followed by None for the remaining inputs + return *grads, None, None, None, None, None, None diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 9211f98fd..b8fce9e06 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -22,11 +22,14 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, label_smoothing: float = 0.0, + simpo_gamma: float = 0.5, + loss_type: str = "sigmoid", ): super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index) # Sigmoid defaults to the CPO loss defined in the paper listed above. - self.loss_type = "sigmoid" + self.loss_type = loss_type self.label_smoothing = label_smoothing + self.simpo_gamma = simpo_gamma def alignment_loss( self, @@ -55,6 +58,12 @@ def alignment_loss( F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + F.logsigmoid(-self.beta * logits) * self.label_smoothing ) + elif self.loss_type == "simpo": + logits = logits - (self.simpo_gamma / self.beta) + losses = ( + F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + + F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) else: raise ValueError( f"Unknown loss type: {self.loss_type}. 
Should be one of ['sigmoid']" @@ -66,7 +75,6 @@ def alignment_loss( @pytest.mark.parametrize( "B, T, H, V", [ - # (1, 2, 12, 128), (8, 128, 1024, 4096), (3, 47, 31, 123), # random shape ], diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py new file mode 100644 index 000000000..727aaa56e --- /dev/null +++ b/test/chunked_loss/test_simpo_loss.py @@ -0,0 +1,78 @@ +from test.chunked_loss.test_cpo_loss import HFCPOLoss +from test.utils import assert_verbose_allclose, set_seed + +import pytest +import torch + +from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction + +# set random seed globally +set_seed() + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)] +) +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma +): + B = 2 * B # SimPO loss requires B to be even + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + _weight = torch.randn(V, H, device="cuda", dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1 = HFCPOLoss( + ignore_index=ignore_index, beta=beta, simpo_gamma=gamma, loss_type="simpo" + ).get_batch_loss_metrics(input1, weight1, target, bias1) + loss2 = LigerFusedLinearSimPOFunction.apply( + input2, weight2, target, bias2, ignore_index, beta, 1.0, True, True, gamma + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 8ac0309fb..f1b919687 100644 --- a/test/utils.py +++ b/test/utils.py @@ -406,6 +406,7 @@ def concatenated_forward( weight: torch.FloatTensor, target: torch.LongTensor, bias: torch.FloatTensor = None, + average_log_prob: bool = True, ) -> Tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor ]: @@ -438,7 +439,7 @@ def cross_entropy_loss(logits, labels): all_logps = self.get_batch_logps( all_logits, target, - average_log_prob=True, + average_log_prob=average_log_prob, ) chosen_logps = all_logps[:len_chosen] @@ -462,10 +463,13 @@ def get_batch_loss_metrics( target: torch.LongTensor, bias: torch.FloatTensor = None, alpha: float = 1.0, + average_log_prob: bool = True, ): """Compute the ORPO loss and 
other metrics for the given batch of inputs for train or test.""" - forward_output = self.concatenated_forward(_input, weight, target, bias) + forward_output = self.concatenated_forward( + _input, weight, target, bias, average_log_prob + ) ( policy_chosen_logps, policy_rejected_logps, @@ -475,6 +479,6 @@ def get_batch_loss_metrics( ) = forward_output[:5] losses = self.alignment_loss(policy_chosen_logps, policy_rejected_logps) - # full ORPO loss + # full loss loss = policy_nll_loss * alpha - losses.mean() return loss
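
For reference, the objective the new kernel implements is Eq. 6 of the SimPO paper: with length-averaged log-probabilities (the `average_log_prob=True` path threaded through `test/utils.py` above), the loss is `-log sigmoid(beta * (avg_logp_chosen - avg_logp_rejected) - gamma)`. Below is a minimal standalone sketch, not the fused chunked kernel; note that `preference_loss_fn` in `simpo_loss.py` returns the positive log-sigmoid, and the sign is applied where it is combined with the NLL term, as in the reference implementation's `loss = policy_nll_loss * alpha - losses.mean()`.

```python
# Standalone reference sketch of the SimPO objective (Eq. 6, arxiv.org/pdf/2405.14734).
# Assumes chosen/rejected log-probs are already averaged over response length.
import torch
import torch.nn.functional as F


def simpo_loss(avg_chosen_logps, avg_rejected_logps, beta=0.1, gamma=0.5):
    """-log(sigmoid(beta * (chosen - rejected) - gamma)), averaged over the batch."""
    margin = beta * (avg_chosen_logps - avg_rejected_logps) - gamma
    return -F.logsigmoid(margin).mean()


# Toy batch of two preference pairs; the values are illustrative only.
chosen = torch.tensor([-0.9, -1.2])
rejected = torch.tensor([-1.6, -1.3])
print(simpo_loss(chosen, rejected))
```
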