diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index bec3d6e19..5f1b17cf5 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -59,6 +59,7 @@ def forward( weight, target, bias=None, + ref_input=None, ref_weight=None, ref_bias=None, ignore_index=-100, @@ -79,6 +80,7 @@ def forward( compute_nll_loss=compute_nll_loss, compiled=compiled, use_ref_model=use_ref_model, + ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, ) @@ -86,7 +88,7 @@ def forward( @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None, None, None, None + return *grads, None, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -118,13 +120,21 @@ def __init__( self.use_ref_model = use_ref_model def forward( - self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None + self, + lin_weight, + _input, + target, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, ): return LigerFusedLinearDPOFunction.apply( _input, lin_weight, target, bias, + ref_input, ref_weight, ref_bias, self.ignore_index, diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 57afabc80..fff0791ec 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -29,7 +29,7 @@ def forward( compute_nll_loss=True, compiled=True, use_ref_model=False, - # TODO: ref input + ref_input=None, ref_weight=None, ref_bias=None, **loss_kwargs, @@ -97,20 +97,26 @@ def forward( **loss_kwargs, ) - def fused_fwd_bwd(input_chunk, target_chunk): + def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk): """ Fused forward and backward pass for a chunk of input and target. 
""" if bias is not None: return torch.func.grad_and_value( compute_loss, argnums=(0, 1, 3), has_aux=True - )(input_chunk, weight, target_chunk, bias) + )( + input_chunk, + weight, + target_chunk, + bias, + ref_input_chunk=ref_input_chunk, + ) else: return torch.func.grad_and_value( compute_loss, argnums=(0, 1), has_aux=True - )(input_chunk, weight, target_chunk) + )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk) - def accumulate_chunk(input_chunk, target_chunk): + def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): if bias is not None: (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( chunk_loss, @@ -122,7 +128,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = fused_fwd_bwd(input_chunk, target_chunk) + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) grad_bias.add_(chunk_grad_bias) # accumulate bias gradient else: (chunk_grad_input, chunk_grad_weight), ( @@ -135,7 +141,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = fused_fwd_bwd(input_chunk, target_chunk) + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) # Accumulate gradients grad_weight.add_(chunk_grad_weight) @@ -182,18 +188,43 @@ def accumulate_chunk(input_chunk, target_chunk): _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) + if use_ref_model: + _ref_chosen_input_chunks = torch.chunk( + ref_input[:len_chosen], chunks=chunks, dim=0 + ) + _ref_rejected_input_chunks = torch.chunk( + ref_input[len_chosen:], chunks=chunks, dim=0 + ) + for ( chosen_input_chunk, rejected_input_chunk, chosen_target_chunk, rejected_target_chunk, + ref_chosen_input_chunk, + ref_rejected_input_chunk, ) in zip( _chosen_input_chunks, _rejected_input_chunks, _chosen_target_chunks, _rejected_target_chunks, + ( + _ref_chosen_input_chunks + if use_ref_model + else [None] * len(_chosen_input_chunks) + ), + ( + _ref_rejected_input_chunks + if use_ref_model + else [None] * len(_rejected_input_chunks) + ), ): input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) + ref_input_chunk = ( + torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) + if use_ref_model + else None + ) target_chunk = torch.cat( [chosen_target_chunk, rejected_target_chunk], dim=0 ) @@ -202,9 +233,10 @@ def accumulate_chunk(input_chunk, target_chunk): torch._dynamo.mark_dynamic(input_chunk, 1) torch._dynamo.mark_dynamic(target_chunk, 1) torch._dynamo.mark_dynamic(target, 1) + torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None # accumulate loss, gradients, and metrics - accumulate_chunk(input_chunk, target_chunk) + accumulate_chunk(input_chunk, target_chunk, ref_input_chunk) # combine grad_chosen_inputs and grad_rejected_inputs grad_inputs = grad_chosen_inputs + grad_rejected_inputs @@ -301,6 +333,7 @@ def _compute_loss( beta=0.1, compute_nll_loss=True, use_ref_model=False, + ref_input_chunk=None, ref_weight=None, ref_bias=None, **loss_kwargs, @@ -357,7 +390,7 @@ def _compute_loss( ref_rejected_logits, ref_chosen_nll_loss, ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, + ref_input_chunk, ref_weight, target_chunk, ref_bias, diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 0dba17df8..0ac8faeb8 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -75,9 +75,15 @@ def __init__( 
ignore_index=ignore_index, beta=beta, use_ref_model=True ).get_batch_loss_metrics - def forward(self, x, y): + def forward(self, x, ref_x, y): return self.dpo_loss( - self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, ) @@ -103,9 +109,15 @@ def __init__( ignore_index=ignore_index, beta=beta, use_ref_model=True ) - def forward(self, x, y): + def forward(self, x, ref_x, y): return self.dpo_loss( - self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias + self.lin.weight, + x, + y, + self.lin.bias, + ref_x, + self.ref_lin.weight, + self.ref_lin.bias, ) @@ -170,6 +182,10 @@ def test_correctness( input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) + ref_input = ( + torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + ) + target = torch.randint( 0, V, @@ -185,8 +201,8 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, target) + loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target) + loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) @@ -242,6 +258,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) + ref_input = ( + torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + ) + target = torch.randint( 0, V, @@ -270,10 +290,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply( - input1, weight1, target, bias1, ref_weight1, ref_bias1 + input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1 ) loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo( - input2, weight2, target, bias2, ref_weight2, ref_bias2 + input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2 ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index ef2adbf2b..3d3799ad0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -478,6 +478,7 @@ def get_batch_loss_metrics( _input: torch.FloatTensor, target: torch.LongTensor, bias: torch.FloatTensor = None, + ref_input: torch.FloatTensor = None, ref_weight: torch.FloatTensor = None, ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, @@ -498,7 +499,7 @@ def get_batch_loss_metrics( loss_kwargs = {} if self.use_ref_model: ref_chosen_logps, ref_rejected_logps = self.get_ref_logps( - _input, ref_weight, target, ref_bias, average_log_prob + ref_input, ref_weight, target, ref_bias, average_log_prob ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
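
Usage sketch (illustrative only, not part of the patch): with ref_input threaded through, callers now pass the reference model's hidden states explicitly instead of the policy hidden states being reused to compute the reference log-probabilities. Only the LigerFusedLinearDPOLoss constructor and forward signatures come from the diff above; the toy sizes, CUDA device, and the policy/reference head modules below are assumptions made for the example.

import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

device = "cuda"  # assumption: the repo's tests exercise the chunked kernels on GPU
B, T, H, V = 2, 64, 512, 1024  # toy sizes; the batch holds B chosen + B rejected rows

# policy and (frozen) reference LM heads; bias-free heads are an assumption for brevity
policy_head = torch.nn.Linear(H, V, bias=False, device=device)
ref_head = torch.nn.Linear(H, V, bias=False, device=device)

dpo_loss = LigerFusedLinearDPOLoss(ignore_index=-100, beta=0.1, use_ref_model=True)

# hidden states from the policy and the reference model, concatenated as
# [chosen; rejected] along the batch dimension, mirroring the tests above
policy_hidden = torch.randn(2 * B, T, H, device=device, requires_grad=True)
ref_hidden = torch.randn(2 * B, T, H, device=device)  # no grad on the reference path
target = torch.randint(0, V, (2 * B, T), device=device)

loss, aux_outputs = dpo_loss(
    policy_head.weight,
    policy_hidden,
    target,
    bias=None,
    ref_input=ref_hidden,       # new argument introduced by this change
    ref_weight=ref_head.weight,
    ref_bias=None,
)
loss.backward()

Note that ref_input sits between bias and ref_weight in both the functional LigerFusedLinearDPOFunction.apply(...) argument order and the module forward, which is why backward now returns one additional None.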