modify ref_input in chunked_loss base class and fix tests (#470)
## Summary
Add a `ref_input` argument to the chunked preference loss base class and the DPO loss so the reference model runs on its own input instead of reusing the policy model's input, and update the tests to pass it. Aims to fix #447.
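
For context, here is a minimal sketch of the calling convention after this change, mirroring the updated tests below. The shapes, device, and random tensors are illustrative assumptions rather than part of the diff; the batch is assumed to hold chosen examples in its first half and rejected examples in its second half.

```python
import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

# Illustrative sizes (assumed): batch, sequence length, hidden size, vocab size.
B, T, H, V = 4, 16, 64, 128
device, dtype = "cuda", torch.float32  # assumed; the test suite runs on GPU

_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
bias = torch.randn(V, device=device, dtype=dtype, requires_grad=True)
target = torch.randint(0, V, (B, T), device=device, dtype=torch.long)

# New in this change: the frozen reference model gets its own hidden states
# (`ref_input`) instead of reusing the policy model's `_input`.
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)
ref_weight = torch.randn(V, H, device=device, dtype=dtype)
ref_bias = torch.randn(V, device=device, dtype=dtype)

loss, aggregated_aux_outputs = LigerFusedLinearDPOFunction.apply(
    _input, weight, target, bias, ref_input, ref_weight, ref_bias
)
```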

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
shivam15s and ByronHsu authored Dec 12, 2024
1 parent 6c68bcb commit 55e3755
Showing 4 changed files with 84 additions and 20 deletions.
14 changes: 12 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -59,6 +59,7 @@ def forward(
weight,
target,
bias=None,
+ ref_input=None,
ref_weight=None,
ref_bias=None,
ignore_index=-100,
@@ -79,14 +80,15 @@
compute_nll_loss=compute_nll_loss,
compiled=compiled,
use_ref_model=use_ref_model,
+ ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -118,13 +120,21 @@ def __init__(
self.use_ref_model = use_ref_model

def forward(
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+ self,
+ lin_weight,
+ _input,
+ target,
+ bias=None,
+ ref_input=None,
+ ref_weight=None,
+ ref_bias=None,
):
return LigerFusedLinearDPOFunction.apply(
_input,
lin_weight,
target,
bias,
+ ref_input,
ref_weight,
ref_bias,
self.ignore_index,
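
One note on the `backward` change in this file: `torch.autograd.Function.backward` returns one gradient slot per `forward` input, with `None` for non-differentiable arguments, so adding `ref_input` to `forward` requires one more trailing `None`. A self-contained toy sketch of that rule (hypothetical names, unrelated to the DPO math):

```python
import torch


class ScaleByConstant(torch.autograd.Function):
    """Toy Function: one differentiable input plus two non-differentiable extras."""

    @staticmethod
    def forward(ctx, x, scale, tag=None):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return slot per forward argument: a gradient for `x`,
        # and None for the non-differentiable `scale` and `tag`.
        return grad_output * ctx.scale, None, None


x = torch.randn(3, requires_grad=True)
ScaleByConstant.apply(x, 2.0, "debug").sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```
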
51 changes: 42 additions & 9 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -29,7 +29,7 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
- # TODO: ref input
+ ref_input=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
@@ -97,20 +97,26 @@ def forward(
**loss_kwargs,
)

- def fused_fwd_bwd(input_chunk, target_chunk):
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
"""
Fused forward and backward pass for a chunk of input and target.
"""
if bias is not None:
return torch.func.grad_and_value(
compute_loss, argnums=(0, 1, 3), has_aux=True
- )(input_chunk, weight, target_chunk, bias)
+ )(
+ input_chunk,
+ weight,
+ target_chunk,
+ bias,
+ ref_input_chunk=ref_input_chunk,
+ )
else:
return torch.func.grad_and_value(
compute_loss, argnums=(0, 1), has_aux=True
- )(input_chunk, weight, target_chunk)
+ )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)

- def accumulate_chunk(input_chunk, target_chunk):
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
@@ -122,7 +128,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
- ) = fused_fwd_bwd(input_chunk, target_chunk)
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
else:
(chunk_grad_input, chunk_grad_weight), (
@@ -135,7 +141,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
- ) = fused_fwd_bwd(input_chunk, target_chunk)
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)

# Accumulate gradients
grad_weight.add_(chunk_grad_weight)
@@ -182,18 +188,43 @@ def accumulate_chunk(input_chunk, target_chunk):
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

+ if use_ref_model:
+ _ref_chosen_input_chunks = torch.chunk(
+ ref_input[:len_chosen], chunks=chunks, dim=0
+ )
+ _ref_rejected_input_chunks = torch.chunk(
+ ref_input[len_chosen:], chunks=chunks, dim=0
+ )
+
for (
chosen_input_chunk,
rejected_input_chunk,
chosen_target_chunk,
rejected_target_chunk,
+ ref_chosen_input_chunk,
+ ref_rejected_input_chunk,
) in zip(
_chosen_input_chunks,
_rejected_input_chunks,
_chosen_target_chunks,
_rejected_target_chunks,
+ (
+ _ref_chosen_input_chunks
+ if use_ref_model
+ else [None] * len(_chosen_input_chunks)
+ ),
+ (
+ _ref_rejected_input_chunks
+ if use_ref_model
+ else [None] * len(_rejected_input_chunks)
+ ),
):
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+ ref_input_chunk = (
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
+ if use_ref_model
+ else None
+ )
target_chunk = torch.cat(
[chosen_target_chunk, rejected_target_chunk], dim=0
)
@@ -202,9 +233,10 @@ def accumulate_chunk(input_chunk, target_chunk):
torch._dynamo.mark_dynamic(input_chunk, 1)
torch._dynamo.mark_dynamic(target_chunk, 1)
torch._dynamo.mark_dynamic(target, 1)
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None

# accumulate loss, gradients, and metrics
- accumulate_chunk(input_chunk, target_chunk)
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -301,6 +333,7 @@ def _compute_loss(
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
+ ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
@@ -357,7 +390,7 @@ def _compute_loss(
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
- input_chunk,
+ ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
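
On the chunking logic above: when `use_ref_model` is `False` there are no reference chunks, so `[None] * len(...)` placeholder lists keep the `zip` over chosen/rejected chunks aligned and let `ref_input_chunk` flow through unconditionally as `None`. A stripped-down sketch of that pattern (standalone, with illustrative names and shapes):

```python
import torch

# Standalone illustration of the placeholder pattern; sizes are arbitrary.
use_ref_model = False  # flip to True to chunk a real ref_input instead
chunks = 4

_input = torch.randn(8, 5)
ref_input = torch.randn(8, 5)

input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
ref_input_chunks = (
    torch.chunk(ref_input, chunks=chunks, dim=0)
    if use_ref_model
    else [None] * len(input_chunks)  # placeholders keep zip() aligned
)

for input_chunk, ref_input_chunk in zip(input_chunks, ref_input_chunks):
    # ref_input_chunk is None when no reference model is used, so it can be
    # forwarded unconditionally (e.g. as a keyword argument defaulting to None).
    print(input_chunk.shape, ref_input_chunk)
```
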
36 changes: 28 additions & 8 deletions test/chunked_loss/test_dpo_loss.py
@@ -75,9 +75,15 @@ def __init__(
ignore_index=ignore_index, beta=beta, use_ref_model=True
).get_batch_loss_metrics

- def forward(self, x, y):
+ def forward(self, x, ref_x, y):
return self.dpo_loss(
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+ self.lin.weight,
+ x,
+ y,
+ self.lin.bias,
+ ref_x,
+ self.ref_lin.weight,
+ self.ref_lin.bias,
)


@@ -103,9 +109,15 @@ def __init__(
ignore_index=ignore_index, beta=beta, use_ref_model=True
)

- def forward(self, x, y):
+ def forward(self, x, ref_x, y):
return self.dpo_loss(
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+ self.lin.weight,
+ x,
+ y,
+ self.lin.bias,
+ ref_x,
+ self.ref_lin.weight,
+ self.ref_lin.bias,
)


@@ -170,6 +182,10 @@ def test_correctness(
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

+ ref_input = (
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
+ )
+
target = torch.randint(
0,
V,
@@ -185,8 +201,8 @@ def test_correctness(
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

- loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, target)
- loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, target)
+ loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target)
+ loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

@@ -242,6 +258,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

+ ref_input = (
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
+ )
+
target = torch.randint(
0,
V,
@@ -270,10 +290,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None

loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
- input1, weight1, target, bias1, ref_weight1, ref_bias1
+ input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1
)
loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
- input2, weight2, target, bias2, ref_weight2, ref_bias2
+ input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2
)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
3 changes: 2 additions & 1 deletion test/utils.py
@@ -478,6 +478,7 @@ def get_batch_loss_metrics(
_input: torch.FloatTensor,
target: torch.LongTensor,
bias: torch.FloatTensor = None,
+ ref_input: torch.FloatTensor = None,
ref_weight: torch.FloatTensor = None,
ref_bias: torch.FloatTensor = None,
average_log_prob: bool = True,
@@ -498,7 +499,7 @@ def get_batch_loss_metrics(
loss_kwargs = {}
if self.use_ref_model:
ref_chosen_logps, ref_rejected_logps = self.get_ref_logps(
- _input, ref_weight, target, ref_bias, average_log_prob
+ ref_input, ref_weight, target, ref_bias, average_log_prob
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
