From 8bcde33b9b6b719366d74c875581b7c0f206246d Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 11 Nov 2024 16:19:29 +0800 Subject: [PATCH 1/6] Add chunked DPO Loss Signed-off-by: Austin Liu Fix benchmark script --- benchmark/scripts/benchmark_dpo_loss.py | 229 ++++++++++++++++++ src/liger_kernel/alignment/dpo_loss.py | 301 ++++++++++++++++++++++++ test/alignment/test_dpo_loss.py | 75 ++++++ 3 files changed, 605 insertions(+) create mode 100644 benchmark/scripts/benchmark_dpo_loss.py create mode 100644 src/liger_kernel/alignment/dpo_loss.py create mode 100644 test/alignment/test_dpo_loss.py diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py new file mode 100644 index 000000000..fa492ab9b --- /dev/null +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -0,0 +1,229 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction + + +class TorchDPOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.1, + ignore_index: int = -100, + bias: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index) + + def forward(self, x, target): + return self.dpo_loss.get_batch_loss_metrics( + x, + self.lin.weight, + target, + self.lin.bias if hasattr(self.lin, "bias") else None, + ) + + +class LigerDPOLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + beta: float = 0.1, + ignore_index: int = -100, + bias: bool = False, + ): + super().__init__() + self.lin = torch.nn.Linear( + in_features=H, out_features=V, bias=bias, dtype=dtype + ) + self.beta = beta + self.ignore_index = ignore_index + + def forward(self, x, target): + return LigerFusedLinearDPOFunction.apply( + x, + self.lin.weight, + target, + self.lin.bias if hasattr(self.lin, "bias") else None, + self.ignore_index, + self.beta, + True, + ) + + +def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x // 2 + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + device = "cuda" + torch_dpo_loss = TorchDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + liger_dpo_loss = LigerDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + + # Input shape: [B*2, T, H] + _input = torch.randn(B * 2, T, H, requires_grad=True, dtype=dtype, device=device) + # Target shape: [B*2, T] + target = torch.randint(V, (B * 2, T), dtype=torch.long, device=device) + + # Add ignore_index tokens to simulate padding + num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_dpo_loss(_input, target) + elif provider == "huggingface": + return torch_dpo_loss(_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x // 2 + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + device = "cuda" + torch_dpo_loss = TorchDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + liger_dpo_loss = LigerDPOLoss( + H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias + ).to(device) + + # Input shape: [B*2, T, H] + input1 = torch.randn( + B * 2, T, H, device=device, dtype=dtype + ) # .detach().clone().requires_grad_(True) + input2 = torch.randn( + B * 2, T, H, device=device, dtype=dtype + ) # .detach().clone().requires_grad_(True) + + # Target shape: [B*2, T] + target = torch.randint(0, V, (B * 2, T), device=device, dtype=torch.long) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + def fwd(): + if provider == "liger": + return liger_dpo_loss(input1, target) + elif provider == "huggingface": + return torch_dpo_loss(input2, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[input1, input2], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "dpo_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_dpo_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + + run_benchmarks( + bench_test_fn=bench_memory_dpo_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/alignment/dpo_loss.py b/src/liger_kernel/alignment/dpo_loss.py new file mode 100644 index 000000000..6d54cb45c --- /dev/null +++ b/src/liger_kernel/alignment/dpo_loss.py @@ -0,0 +1,301 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from triton import next_power_of_2 + +from liger_kernel.ops.utils import element_mul_kernel + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 +# Setting limit as 65536 for better performance due to less register spilling +MAX_FUSED_SIZE = 65536 // 2 + + +def dpo_loss(chosen_logps, rejected_logps, beta=0.1): + """ + Compute DPO loss (Direct Preference Optimization). + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Temperature parameter for the DPO loss. + """ + logits_diff = (chosen_logps - rejected_logps) / beta + losses = -F.logsigmoid(logits_diff) + return losses.sum() + + +class LigerFusedLinearDPOFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + compiled=True, + ): + """ + Fused linear layer with DPO (Direct Preference Optimization) loss. + Handles both the forward and backward pass of the final linear layer with DPO loss. + + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + ignore_index (int): Index to ignore for loss computation. + beta (float): Temperature parameter for the DPO loss. + compiled (bool): Whether to use torch compile for chunk accumulation. + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = 1 + + def _compute_dpo_loss(input_chunk, weight, target_chunk, bias=None): + len_chosen_chunk = target_chunk.shape[0] // 2 + + unnorm_logits = input_chunk @ weight.t() # chunk_size x V + if bias is not None: + unnorm_logits = unnorm_logits + bias + unnorm_logits = unnorm_logits.float() + norm_logits = F.log_softmax(unnorm_logits, dim=-1) + + # Compute NLL loss for chosen responses + chosen_nll_loss = F.nll_loss( + norm_logits[:len_chosen_chunk].view(-1, norm_logits.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + chosen_nll_loss = ( + chosen_nll_loss / (target[: target.shape[0] // 2] != ignore_index).sum() + ) + + # Compute log probabilities for both chosen and rejected + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + per_token_logps = norm_logits.gather(-1, label_chunk.unsqueeze(-1)).squeeze( + -1 + ) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + + # Compute DPO loss + preference_loss = dpo_loss(chosen_logps, rejected_logps, beta=beta) + preference_loss = preference_loss / (target.shape[0] // 2) + + # Total loss combines NLL and DPO loss + loss = chosen_nll_loss + preference_loss + return loss, (preference_loss, chosen_logps, rejected_logps) + + def compute_dpo_loss(input_chunk, weight, target_chunk, bias=None): + return _compute_dpo_loss(input_chunk, weight, target_chunk, bias) + + grad_weight = torch.zeros_like(weight) + grad_chosen_inputs = [] + grad_rejected_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None + loss_acc = torch.zeros((), device=_input.device) + + chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) + + def accumulate_chunk(input_chunk, target_chunk): + if bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( + chunk_loss, + (chunk_dpo_loss, chunk_chosen_logps, chunk_rejected_logps), + ) = torch.func.grad_and_value( + compute_dpo_loss, argnums=(0, 1, 3), has_aux=True + )( + input_chunk, weight, target_chunk, bias + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), ( + chunk_loss, + (chunk_dpo_loss, chunk_chosen_logps, chunk_rejected_logps), + ) = torch.func.grad_and_value( + compute_dpo_loss, argnums=(0, 1), has_aux=True + )( + input_chunk, weight, target_chunk + ) + grad_weight.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + return chunk_grad_input + + len_chosen = target.shape[0] // 2 + _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) + _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) + _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) + _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) + + for ( + chosen_input_chunk, + rejected_input_chunk, + chosen_target_chunk, + rejected_target_chunk, + ) in zip( + _chosen_input_chunks, + _rejected_input_chunks, + _chosen_target_chunks, + _rejected_target_chunks, + ): + input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) + target_chunk = torch.cat( + [chosen_target_chunk, rejected_target_chunk], dim=0 + ) + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + grad_input = accumulate_chunk(input_chunk, target_chunk) + + grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]]) + grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :]) + + grad_inputs = grad_chosen_inputs + grad_rejected_inputs + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + return loss_acc + + @staticmethod + def backward(ctx, grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + BT, H = grad_input.view(-1, grad_input.shape[-1]).shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + return grad_input, grad_weight, None, grad_bias, None, None, None + + +class HF_DPO_Loss: + """ + Implementation of Direct Preference Optimization (DPO) loss, + adapted from the Hugging Face implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py + """ + + def __init__(self, ignore_index: int = -100, beta: float = 0.1): + self.ignore_index = ignore_index + self.beta = beta + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits and labels must have the same shape.") + + loss_mask = labels != self.ignore_index + labels = torch.where(labels == self.ignore_index, 0, labels) + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> torch.FloatTensor: + """Compute DPO loss for a batch of policy log probabilities.""" + logits_diff = (policy_chosen_logps - policy_rejected_logps) / self.beta + losses = -F.logsigmoid(logits_diff) + return losses + + def concatenated_forward( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ): + len_chosen = _input.shape[0] // 2 + + outputs = _input @ weight.t() + if bias is not None: + outputs = outputs + bias + all_logits = outputs.float() + + def cross_entropy_loss(logits, labels): + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], target[:len_chosen] + ) + + all_logps = self.get_batch_logps( + all_logits, + target, + average_log_prob=True, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + return chosen_logps, rejected_logps, chosen_nll_loss + + def get_batch_loss_metrics( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ): + policy_chosen_logps, policy_rejected_logps, nll_loss = ( + self.concatenated_forward(_input, weight, target, bias) + ) + + dpo_losses = self.dpo_loss(policy_chosen_logps, policy_rejected_logps) + loss = nll_loss + dpo_losses.mean() + return loss diff --git a/test/alignment/test_dpo_loss.py b/test/alignment/test_dpo_loss.py new file mode 100644 index 000000000..b55f0730f --- /dev/null +++ b/test/alignment/test_dpo_loss.py @@ -0,0 +1,75 @@ +from test.utils import assert_verbose_allclose, set_seed +from typing import Tuple + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction + +# set random seed globally +set_seed() + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): + B = 2 * B # dpo loss requires B to be even + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + _weight = torch.randn(V, H, device="cuda", dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1 = HF_DPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( + input1, weight1, target, bias1 + ) + loss2 = LigerFusedLinearDPOFunction.apply( + input2, weight2, target, bias2, ignore_index, beta, True + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) From be44081333164ebc871d92c62401b34455afa68a Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 13 Nov 2024 14:07:10 +0800 Subject: [PATCH 2/6] Remove unused imports Signed-off-by: Austin Liu --- test/alignment/test_dpo_loss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/alignment/test_dpo_loss.py b/test/alignment/test_dpo_loss.py index b55f0730f..d9fed0502 100644 --- a/test/alignment/test_dpo_loss.py +++ b/test/alignment/test_dpo_loss.py @@ -1,9 +1,7 @@ from test.utils import assert_verbose_allclose, set_seed -from typing import Tuple import pytest import torch -import torch.nn.functional as F from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction From 51970828c1d57c74275e1eb8d2c87d6ab204c2ea Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 13 Nov 2024 14:57:12 +0800 Subject: [PATCH 3/6] Clean up bench Signed-off-by: Austin Liu --- benchmark/scripts/benchmark_dpo_loss.py | 39 +++++++++++-------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index fa492ab9b..8593985ac 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -67,7 +67,7 @@ def forward(self, x, target): def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - B = input.x // 2 + B = input.x T = input.extra_benchmark_config["T"] H = input.extra_benchmark_config["H"] V = input.extra_benchmark_config["V"] @@ -85,14 +85,14 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) - # Input shape: [B*2, T, H] - _input = torch.randn(B * 2, T, H, requires_grad=True, dtype=dtype, device=device) - # Target shape: [B*2, T] - target = torch.randint(V, (B * 2, T), dtype=torch.long, device=device) + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) # Add ignore_index tokens to simulate padding - num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign] + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index def fwd(): @@ -114,7 +114,7 @@ def full(): def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - B = input.x // 2 + B = input.x T = input.extra_benchmark_config["T"] H = input.extra_benchmark_config["H"] V = input.extra_benchmark_config["V"] @@ -133,27 +133,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) - # Input shape: [B*2, T, H] - input1 = torch.randn( - B * 2, T, H, device=device, dtype=dtype - ) # .detach().clone().requires_grad_(True) - input2 = torch.randn( - B * 2, T, H, device=device, dtype=dtype - ) # .detach().clone().requires_grad_(True) + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) - # Target shape: [B*2, T] - target = torch.randint(0, V, (B * 2, T), device=device, dtype=torch.long) + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) # Add ignore_index tokens - num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign] + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index def fwd(): if provider == "liger": - return liger_dpo_loss(input1, target) + return liger_dpo_loss(_input, target) elif provider == "huggingface": - return torch_dpo_loss(input2, target) + return torch_dpo_loss(_input, target) if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench( @@ -165,7 +160,7 @@ def fwd(): y = fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(retain_graph=True), - grad_to_none=[input1, input2], + grad_to_none=[_input], rep=100, quantiles=QUANTILES, ) From e95bda48446f161eb9f2a13d879676dba6033519 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 14 Nov 2024 15:31:10 +0800 Subject: [PATCH 4/6] Refactor Signed-off-by: Austin Liu --- benchmark/scripts/benchmark_dpo_loss.py | 6 +- src/liger_kernel/alignment/dpo_loss.py | 301 ---------------------- src/liger_kernel/chunked_loss/dpo_loss.py | 120 +++++++++ test/alignment/test_dpo_loss.py | 73 ------ test/chunked_loss/test_dpo_loss.py | 219 ++++++++++++++++ 5 files changed, 343 insertions(+), 376 deletions(-) delete mode 100644 src/liger_kernel/alignment/dpo_loss.py create mode 100644 src/liger_kernel/chunked_loss/dpo_loss.py delete mode 100644 test/alignment/test_dpo_loss.py create mode 100644 test/chunked_loss/test_dpo_loss.py diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 8593985ac..7a48edb14 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -1,5 +1,6 @@ +from test.chunked_loss.test_dpo_loss import HF_DPO_Loss + import torch -import triton from utils import ( QUANTILES, SingleBenchmarkRunInput, @@ -9,7 +10,8 @@ run_benchmarks, ) -from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction +import triton +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction class TorchDPOLoss(torch.nn.Module): diff --git a/src/liger_kernel/alignment/dpo_loss.py b/src/liger_kernel/alignment/dpo_loss.py deleted file mode 100644 index 6d54cb45c..000000000 --- a/src/liger_kernel/alignment/dpo_loss.py +++ /dev/null @@ -1,301 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from triton import next_power_of_2 - -from liger_kernel.ops.utils import element_mul_kernel - -# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 -# Setting limit as 65536 for better performance due to less register spilling -MAX_FUSED_SIZE = 65536 // 2 - - -def dpo_loss(chosen_logps, rejected_logps, beta=0.1): - """ - Compute DPO loss (Direct Preference Optimization). - Args: - chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). - rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). - beta (float): Temperature parameter for the DPO loss. - """ - logits_diff = (chosen_logps - rejected_logps) / beta - losses = -F.logsigmoid(logits_diff) - return losses.sum() - - -class LigerFusedLinearDPOFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - _input, - weight, - target, - bias=None, - ignore_index=-100, - beta=0.1, - compiled=True, - ): - """ - Fused linear layer with DPO (Direct Preference Optimization) loss. - Handles both the forward and backward pass of the final linear layer with DPO loss. - - Args: - _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - ignore_index (int): Index to ignore for loss computation. - beta (float): Temperature parameter for the DPO loss. - compiled (bool): Whether to use torch compile for chunk accumulation. - """ - # TODO: Tune CHUNK_SIZE to fully utilize the GPU - CHUNK_SIZE = 1 - - def _compute_dpo_loss(input_chunk, weight, target_chunk, bias=None): - len_chosen_chunk = target_chunk.shape[0] // 2 - - unnorm_logits = input_chunk @ weight.t() # chunk_size x V - if bias is not None: - unnorm_logits = unnorm_logits + bias - unnorm_logits = unnorm_logits.float() - norm_logits = F.log_softmax(unnorm_logits, dim=-1) - - # Compute NLL loss for chosen responses - chosen_nll_loss = F.nll_loss( - norm_logits[:len_chosen_chunk].view(-1, norm_logits.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", - ignore_index=ignore_index, - ) - chosen_nll_loss = ( - chosen_nll_loss / (target[: target.shape[0] // 2] != ignore_index).sum() - ) - - # Compute log probabilities for both chosen and rejected - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - per_token_logps = norm_logits.gather(-1, label_chunk.unsqueeze(-1)).squeeze( - -1 - ) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - - # Compute DPO loss - preference_loss = dpo_loss(chosen_logps, rejected_logps, beta=beta) - preference_loss = preference_loss / (target.shape[0] // 2) - - # Total loss combines NLL and DPO loss - loss = chosen_nll_loss + preference_loss - return loss, (preference_loss, chosen_logps, rejected_logps) - - def compute_dpo_loss(input_chunk, weight, target_chunk, bias=None): - return _compute_dpo_loss(input_chunk, weight, target_chunk, bias) - - grad_weight = torch.zeros_like(weight) - grad_chosen_inputs = [] - grad_rejected_inputs = [] - grad_bias = torch.zeros_like(bias) if bias is not None else None - loss_acc = torch.zeros((), device=_input.device) - - chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) - - def accumulate_chunk(input_chunk, target_chunk): - if bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_loss, - (chunk_dpo_loss, chunk_chosen_logps, chunk_rejected_logps), - ) = torch.func.grad_and_value( - compute_dpo_loss, argnums=(0, 1, 3), has_aux=True - )( - input_chunk, weight, target_chunk, bias - ) - grad_bias.add_(chunk_grad_bias) - else: - (chunk_grad_input, chunk_grad_weight), ( - chunk_loss, - (chunk_dpo_loss, chunk_chosen_logps, chunk_rejected_logps), - ) = torch.func.grad_and_value( - compute_dpo_loss, argnums=(0, 1), has_aux=True - )( - input_chunk, weight, target_chunk - ) - grad_weight.add_(chunk_grad_weight) - loss_acc.add_(chunk_loss) - return chunk_grad_input - - len_chosen = target.shape[0] // 2 - _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) - _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) - _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) - _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) - - for ( - chosen_input_chunk, - rejected_input_chunk, - chosen_target_chunk, - rejected_target_chunk, - ) in zip( - _chosen_input_chunks, - _rejected_input_chunks, - _chosen_target_chunks, - _rejected_target_chunks, - ): - input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) - target_chunk = torch.cat( - [chosen_target_chunk, rejected_target_chunk], dim=0 - ) - - if compiled: - accumulate_chunk = torch.compile(accumulate_chunk) - grad_input = accumulate_chunk(input_chunk, target_chunk) - - grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]]) - grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :]) - - grad_inputs = grad_chosen_inputs + grad_rejected_inputs - - ctx.save_for_backward( - torch.cat(grad_inputs, dim=0), - grad_weight, - grad_bias, - ) - return loss_acc - - @staticmethod - def backward(ctx, grad_output): - grad_input, grad_weight, grad_bias = ctx.saved_tensors - if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): - BT, H = grad_input.view(-1, grad_input.shape[-1]).shape - n_rows = BT - BLOCK_SIZE = min(MAX_FUSED_SIZE, next_power_of_2(H)) - - element_mul_kernel[(n_rows,)]( - grad_input, - grad_input.stride(-2), - grad_output, - H, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - - if grad_weight is not None: - V, H = grad_weight.shape - n_rows = V - - element_mul_kernel[(n_rows,)]( - grad_weight, - grad_weight.stride(-2), - grad_output, - H, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - - if grad_bias is not None: - V = grad_bias.shape[0] - n_rows = V - - element_mul_kernel[(n_rows,)]( - grad_bias, - grad_bias.stride(-1), - grad_output, - 1, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - return grad_input, grad_weight, None, grad_bias, None, None, None - - -class HF_DPO_Loss: - """ - Implementation of Direct Preference Optimization (DPO) loss, - adapted from the Hugging Face implementation. - Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py - """ - - def __init__(self, ignore_index: int = -100, beta: float = 0.1): - self.ignore_index = ignore_index - self.beta = beta - - def get_batch_logps( - self, - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - ) -> torch.FloatTensor: - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits and labels must have the same shape.") - - loss_mask = labels != self.ignore_index - labels = torch.where(labels == self.ignore_index, 0, labels) - per_token_logps = torch.gather( - logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) - ).squeeze(2) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def dpo_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - ) -> torch.FloatTensor: - """Compute DPO loss for a batch of policy log probabilities.""" - logits_diff = (policy_chosen_logps - policy_rejected_logps) / self.beta - losses = -F.logsigmoid(logits_diff) - return losses - - def concatenated_forward( - self, - _input: torch.FloatTensor, - weight: torch.FloatTensor, - target: torch.LongTensor, - bias: torch.FloatTensor = None, - ): - len_chosen = _input.shape[0] // 2 - - outputs = _input @ weight.t() - if bias is not None: - outputs = outputs + bias - all_logits = outputs.float() - - def cross_entropy_loss(logits, labels): - loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - chosen_nll_loss = cross_entropy_loss( - all_logits[:len_chosen], target[:len_chosen] - ) - - all_logps = self.get_batch_logps( - all_logits, - target, - average_log_prob=True, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - return chosen_logps, rejected_logps, chosen_nll_loss - - def get_batch_loss_metrics( - self, - _input: torch.FloatTensor, - weight: torch.FloatTensor, - target: torch.LongTensor, - bias: torch.FloatTensor = None, - ): - policy_chosen_logps, policy_rejected_logps, nll_loss = ( - self.concatenated_forward(_input, weight, target, bias) - ) - - dpo_losses = self.dpo_loss(policy_chosen_logps, policy_rejected_logps) - loss = nll_loss + dpo_losses.mean() - return loss diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py new file mode 100644 index 000000000..12b12c6ff --- /dev/null +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -0,0 +1,120 @@ +from functools import partial + +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_preference import ( + LigerFusedLinearPreferenceBase, +) + + +def dpo_loss(chosen_logps, rejected_logps, beta=0.1): + """ + Compute DPO loss (Direct Preference Optimization). + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Weight for the direct preference loss. + """ + logits_diff = beta * (chosen_logps - rejected_logps) + losses = -F.logsigmoid(logits_diff) + return losses.sum() + + +def _compute_dpo_loss( + input_chunk, + weight, + target_chunk, + bias=None, + full_target=None, + ignore_index=-100, + beta=0.1, + compute_nll_loss=True, +): + """ + Compute DPO loss for a chunk of input and target. + Args: + input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight for the direct preference loss. + """ + + len_chosen_chunk = target_chunk.shape[0] // 2 + + logits_chunk = input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) # Normalize the unnorm_logits + + # Compute NLL loss for chosen responses + chosen_nll_loss = 0.0 + if compute_nll_loss: + chosen_nll_loss = F.nll_loss( + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + + # Compute log probabilities for both chosen and rejected + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + + # Compute DPO loss + preference_loss = dpo_loss(chosen_logps, rejected_logps, beta=beta) + preference_loss = preference_loss / (full_target.shape[0] // 2) + + # Total loss combines NLL and DPO loss + loss = chosen_nll_loss + preference_loss + return loss, (preference_loss, chosen_logps, rejected_logps) + + +class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + beta=0.1, + compute_nll_loss=True, + compiled=True, + ): + """ + Fused linear layer with DPO (Direct Preference Optimization) loss. + Handles both the forward and backward pass of the final linear layer with DPO loss. + Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. + """ + dpo_loss_fn = partial( + _compute_dpo_loss, + full_target=target, + ignore_index=ignore_index, + beta=beta, + compute_nll_loss=compute_nll_loss, + ) + return LigerFusedLinearPreferenceBase.forward( + ctx, _input, weight, target, bias, loss_fn=dpo_loss_fn + ) + + @staticmethod + def backward(ctx, grad_output): + # Get gradients for _input, weight, bias, and target from the base class + grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] + # Return these gradients, followed by None for the remaining inputs + return *grads, None, None, None, None diff --git a/test/alignment/test_dpo_loss.py b/test/alignment/test_dpo_loss.py deleted file mode 100644 index d9fed0502..000000000 --- a/test/alignment/test_dpo_loss.py +++ /dev/null @@ -1,73 +0,0 @@ -from test.utils import assert_verbose_allclose, set_seed - -import pytest -import torch - -from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction - -# set random seed globally -set_seed() - - -@pytest.mark.parametrize( - "B, T, H, V", - [ - (8, 128, 1024, 4096), - (3, 47, 31, 123), # random shape - ], -) -@pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - (1.0, torch.bfloat16, 5e-2, 5e-1), - (1.0, torch.float32, 1e-5, 5e-4), - ], -) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) -def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): - B = 2 * B # dpo loss requires B to be even - - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar - input1 = _input.detach().clone().requires_grad_(True) - input2 = _input.detach().clone().requires_grad_(True) - - target = torch.randint( - 0, - V, - ( - B, - T, - ), - device="cuda", - dtype=torch.long, - ) - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() - indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] - target.view(-1)[indices_to_assign] = ignore_index - - _weight = torch.randn(V, H, device="cuda", dtype=dtype) - weight1 = _weight.detach().clone().requires_grad_(True) - weight2 = _weight.detach().clone().requires_grad_(True) - - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None - bias1 = _bias.detach().clone().requires_grad_(True) if bias else None - bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - - loss1 = HF_DPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( - input1, weight1, target, bias1 - ) - loss2 = LigerFusedLinearDPOFunction.apply( - input2, weight2, target, bias2, ignore_index, beta, True - ) - - assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) - - loss1.backward() - loss2.backward() - - assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) - assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) - if bias: - assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py new file mode 100644 index 000000000..39c18c75e --- /dev/null +++ b/test/chunked_loss/test_dpo_loss.py @@ -0,0 +1,219 @@ +from test.utils import assert_verbose_allclose, set_seed +from typing import Tuple + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction + +# set random seed globally +set_seed() + + +class HF_DPO_Loss: + """ + Implementation of the Direct Preference Optimization (DPO) loss, + adapted from Hugging Face's implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py + """ + + def __init__(self, ignore_index: int = -100, beta: float = 0.1): + self.ignore_index = ignore_index + self.beta = beta + + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError( + "Logits (batch and sequence length dim) and labels must have the same shape." + ) + + loss_mask = labels != self.ignore_index + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == self.ignore_index, 0, labels) + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> torch.FloatTensor: + """Compute DPO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + The losses tensor contains the DPO loss for each example in the batch. + """ + # Derived from https://huggingface.co/papers/2305.18290 + logits_diff = self.beta * (policy_chosen_logps - policy_rejected_logps) + losses = -F.logsigmoid(logits_diff) + return losses + + def concatenated_forward( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + len_chosen = _input.shape[0] // 2 + + outputs = _input @ weight.t() + if bias is not None: + outputs = outputs + bias + all_logits = outputs.float() + + def cross_entropy_loss(logits, labels): + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = target + chosen_nll_loss = cross_entropy_loss( + all_logits[:len_chosen], labels[:len_chosen] + ) + + all_logps = self.get_batch_logps( + all_logits, + target, + average_log_prob=True, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) + + def get_batch_loss_metrics( + self, + _input: torch.FloatTensor, + weight: torch.FloatTensor, + target: torch.LongTensor, + bias: torch.FloatTensor = None, + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + + forward_output = self.concatenated_forward(_input, weight, target, bias) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + + losses = self.dpo_loss(policy_chosen_logps, policy_rejected_logps) + # full DPO loss + loss = policy_nll_loss - losses.mean() + return loss + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): + B = 2 * B # dpo loss requires B to be even + + _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device="cuda", + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + _weight = torch.randn(V, H, device="cuda", dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + loss1 = HF_DPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics( + input1, weight1, target, bias1 + ) + loss2 = LigerFusedLinearDPOFunction.apply( + input2, weight2, target, bias2, ignore_index, beta, True + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) From a50d38b9ec64ac9f9af8213124220b54d721410a Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 14 Nov 2024 15:54:05 +0800 Subject: [PATCH 5/6] Clean up: fmt & fix tol Signed-off-by: Austin Liu --- benchmark/scripts/benchmark_dpo_loss.py | 2 +- src/liger_kernel/chunked_loss/dpo_loss.py | 8 +++++--- test/chunked_loss/test_dpo_loss.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 7a48edb14..537be47bc 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -1,6 +1,7 @@ from test.chunked_loss.test_dpo_loss import HF_DPO_Loss import torch +import triton from utils import ( QUANTILES, SingleBenchmarkRunInput, @@ -10,7 +11,6 @@ run_benchmarks, ) -import triton from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 12b12c6ff..005959f15 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -16,7 +16,7 @@ def dpo_loss(chosen_logps, rejected_logps, beta=0.1): rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). beta (float): Weight for the direct preference loss. """ - logits_diff = beta * (chosen_logps - rejected_logps) + logits_diff = beta * (chosen_logps - rejected_logps) losses = -F.logsigmoid(logits_diff) return losses.sum() @@ -42,13 +42,15 @@ def _compute_dpo_loss( ignore_index (int): Index to ignore for loss computation. beta (float): Weight for the direct preference loss. """ - + len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() # chunk_size x V if bias is not None: logits_chunk = logits_chunk + bias - log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) # Normalize the unnorm_logits + log_probs_chunk = F.log_softmax( + logits_chunk.float(), dim=-1 + ) # Normalize the unnorm_logits # Compute NLL loss for chosen responses chosen_nll_loss = 0.0 diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 39c18c75e..0495fa723 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -76,7 +76,7 @@ def dpo_loss( logits_diff = self.beta * (policy_chosen_logps - policy_rejected_logps) losses = -F.logsigmoid(logits_diff) return losses - + def concatenated_forward( self, _input: torch.FloatTensor, @@ -155,6 +155,7 @@ def get_batch_loss_metrics( loss = policy_nll_loss - losses.mean() return loss + @pytest.mark.parametrize( "B, T, H, V", [ @@ -166,7 +167,7 @@ def get_batch_loss_metrics( "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-1), - (1.0, torch.float32, 1e-5, 5e-4), + (1.0, torch.float32, 2e-2, 5e-1), ], ) @pytest.mark.parametrize("bias", [True, False]) From 854c1b3b3556ea3e23ac6b52a0bf982a175be415 Mon Sep 17 00:00:00 2001 From: shivam15s Date: Fri, 15 Nov 2024 01:21:56 +0000 Subject: [PATCH 6/6] align with interface --- src/liger_kernel/chunked_loss/dpo_loss.py | 109 +++++----------------- 1 file changed, 22 insertions(+), 87 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 005959f15..150cb9e1c 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -1,6 +1,3 @@ -from functools import partial - -import torch import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_preference import ( @@ -8,84 +5,21 @@ ) -def dpo_loss(chosen_logps, rejected_logps, beta=0.1): - """ - Compute DPO loss (Direct Preference Optimization). - Args: - chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). - rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). - beta (float): Weight for the direct preference loss. - """ - logits_diff = beta * (chosen_logps - rejected_logps) - losses = -F.logsigmoid(logits_diff) - return losses.sum() - - -def _compute_dpo_loss( - input_chunk, - weight, - target_chunk, - bias=None, - full_target=None, - ignore_index=-100, - beta=0.1, - compute_nll_loss=True, -): - """ - Compute DPO loss for a chunk of input and target. - Args: - input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). - ignore_index (int): Index to ignore for loss computation. - beta (float): Weight for the direct preference loss. - """ - - len_chosen_chunk = target_chunk.shape[0] // 2 - - logits_chunk = input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - log_probs_chunk = F.log_softmax( - logits_chunk.float(), dim=-1 - ) # Normalize the unnorm_logits - - # Compute NLL loss for chosen responses - chosen_nll_loss = 0.0 - if compute_nll_loss: - chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", - ignore_index=ignore_index, - ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - - # Compute log probabilities for both chosen and rejected - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] - - # Compute DPO loss - preference_loss = dpo_loss(chosen_logps, rejected_logps, beta=beta) - preference_loss = preference_loss / (full_target.shape[0] // 2) - - # Total loss combines NLL and DPO loss - loss = chosen_nll_loss + preference_loss - return loss, (preference_loss, chosen_logps, rejected_logps) +class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): + @staticmethod + def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): + """ + Compute DPO loss (Direct Preference Optimization). + Args: + chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). + rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + beta (float): Weight for the direct preference loss. + """ + logits_diff = beta * (chosen_logps - rejected_logps) + losses = -F.logsigmoid(logits_diff) + return losses.sum() -class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): @staticmethod def forward( ctx, @@ -101,17 +35,18 @@ def forward( """ Fused linear layer with DPO (Direct Preference Optimization) loss. Handles both the forward and backward pass of the final linear layer with DPO loss. - Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. """ - dpo_loss_fn = partial( - _compute_dpo_loss, - full_target=target, + return LigerFusedLinearPreferenceBase.forward( + ctx=ctx, + _input=_input, + weight=weight, + target=target, + bias=bias, + loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn, + compute_nll_loss=compute_nll_loss, ignore_index=ignore_index, beta=beta, - compute_nll_loss=compute_nll_loss, - ) - return LigerFusedLinearPreferenceBase.forward( - ctx, _input, weight, target, bias, loss_fn=dpo_loss_fn + compiled=compiled, ) @staticmethod