modify ref_input in chunked_loss base class and fix tests #470

Merged · 4 commits · Dec 12, 2024
14 changes: 12 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -59,6 +59,7 @@ def forward(
weight,
target,
bias=None,
+ ref_input=None,
ref_weight=None,
ref_bias=None,
ignore_index=-100,
@@ -79,14 +80,15 @@ def forward(
compute_nll_loss=compute_nll_loss,
compiled=compiled,
use_ref_model=use_ref_model,
+ ref_input=ref_input,
ref_weight=ref_weight,
ref_bias=ref_bias,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -118,13 +120,21 @@ def __init__(
self.use_ref_model = use_ref_model

def forward(
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+ self,
+ lin_weight,
+ _input,
+ target,
+ bias=None,
+ ref_input=None,
+ ref_weight=None,
+ ref_bias=None,
):
return LigerFusedLinearDPOFunction.apply(
_input,
lin_weight,
target,
bias,
+ ref_input,
ref_weight,
ref_bias,
self.ignore_index,
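For orientation, a minimal sketch (not part of the PR) of calling the updated module API with the new ref_input argument. The tensor shapes, sizes, and device choice below are illustrative assumptions; the argument order (ref_input between bias and ref_weight) and the constructor options come from the diff above.

import torch
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# Toy sizes only; B must be even (first half chosen, second half rejected).
# In practice these tensors live on GPU.
B, T, H, V = 4, 8, 16, 32
lin_weight = torch.randn(V, H, requires_grad=True)   # policy LM head
bias = torch.randn(V, requires_grad=True)
ref_weight = torch.randn(V, H)                       # frozen reference LM head
ref_bias = torch.randn(V)

_input = torch.randn(B, T, H, requires_grad=True)    # policy hidden states
ref_input = torch.randn(B, T, H)                     # reference-model hidden states (new)
target = torch.randint(0, V, (B, T))

dpo_loss = LigerFusedLinearDPOLoss(ignore_index=-100, beta=0.1, use_ref_model=True)
loss, aux_outputs = dpo_loss(lin_weight, _input, target, bias, ref_input, ref_weight, ref_bias)
loss.backward()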
51 changes: 42 additions & 9 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -29,7 +29,7 @@ def forward(
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
- # TODO: ref input
+ ref_input=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
@@ -97,20 +97,26 @@ def forward(
**loss_kwargs,
)

- def fused_fwd_bwd(input_chunk, target_chunk):
+ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
"""
Fused forward and backward pass for a chunk of input and target.
"""
if bias is not None:
return torch.func.grad_and_value(
compute_loss, argnums=(0, 1, 3), has_aux=True
- )(input_chunk, weight, target_chunk, bias)
+ )(
+ input_chunk,
+ weight,
+ target_chunk,
+ bias,
+ ref_input_chunk=ref_input_chunk,
+ )
else:
return torch.func.grad_and_value(
compute_loss, argnums=(0, 1), has_aux=True
- )(input_chunk, weight, target_chunk)
+ )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)

- def accumulate_chunk(input_chunk, target_chunk):
+ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
@@ -122,7 +128,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
- ) = fused_fwd_bwd(input_chunk, target_chunk)
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
else:
(chunk_grad_input, chunk_grad_weight), (
@@ -135,7 +141,7 @@ def accumulate_chunk(input_chunk, target_chunk):
chunk_nll_loss,
*aux_outputs,
),
- ) = fused_fwd_bwd(input_chunk, target_chunk)
+ ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)

# Accumulate gradients
grad_weight.add_(chunk_grad_weight)
@@ -182,18 +188,43 @@ def accumulate_chunk(input_chunk, target_chunk):
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

+ if use_ref_model:
+ _ref_chosen_input_chunks = torch.chunk(
+ ref_input[:len_chosen], chunks=chunks, dim=0
+ )
+ _ref_rejected_input_chunks = torch.chunk(
+ ref_input[len_chosen:], chunks=chunks, dim=0
+ )

for (
chosen_input_chunk,
rejected_input_chunk,
chosen_target_chunk,
rejected_target_chunk,
+ ref_chosen_input_chunk,
+ ref_rejected_input_chunk,
) in zip(
_chosen_input_chunks,
_rejected_input_chunks,
_chosen_target_chunks,
_rejected_target_chunks,
+ (
+ _ref_chosen_input_chunks
+ if use_ref_model
+ else [None] * len(_chosen_input_chunks)
+ ),
+ (
+ _ref_rejected_input_chunks
+ if use_ref_model
+ else [None] * len(_rejected_input_chunks)
+ ),
):
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+ ref_input_chunk = (
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
+ if use_ref_model
+ else None
+ )
target_chunk = torch.cat(
[chosen_target_chunk, rejected_target_chunk], dim=0
)
@@ -202,9 +233,10 @@ def accumulate_chunk(input_chunk, target_chunk):
torch._dynamo.mark_dynamic(input_chunk, 1)
torch._dynamo.mark_dynamic(target_chunk, 1)
torch._dynamo.mark_dynamic(target, 1)
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None

# accumulate loss, gradients, and metrics
- accumulate_chunk(input_chunk, target_chunk)
+ accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -301,6 +333,7 @@ def _compute_loss(
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
+ ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
@@ -357,7 +390,7 @@ def _compute_loss(
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
- input_chunk,
+ ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
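To make the new chunking logic easier to follow, here is a standalone sketch of the pattern above, using assumed toy shapes rather than the library code: ref_input is split into chosen/rejected chunks only when use_ref_model is True, and [None] placeholders keep the zip aligned otherwise.

import torch

B, T, H = 4, 3, 8            # assumed toy sizes; B must be even
chunks, use_ref_model = 2, True
len_chosen = B // 2

_input = torch.randn(B, T, H)
ref_input = torch.randn(B, T, H)

_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)

if use_ref_model:
    _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
    _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)

for chosen, rejected, ref_chosen, ref_rejected in zip(
    _chosen_input_chunks,
    _rejected_input_chunks,
    _ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks),
    _ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks),
):
    # Each step pairs policy activations with the matching reference activations.
    input_chunk = torch.cat([chosen, rejected], dim=0)
    ref_input_chunk = (
        torch.cat([ref_chosen, ref_rejected], dim=0) if use_ref_model else None
    )
    print(input_chunk.shape, None if ref_input_chunk is None else ref_input_chunk.shape)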
36 changes: 28 additions & 8 deletions test/chunked_loss/test_dpo_loss.py
@@ -75,9 +75,15 @@ def __init__(
ignore_index=ignore_index, beta=beta, use_ref_model=True
).get_batch_loss_metrics

- def forward(self, x, y):
+ def forward(self, x, ref_x, y):
return self.dpo_loss(
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+ self.lin.weight,
+ x,
+ y,
+ self.lin.bias,
+ ref_x,
+ self.ref_lin.weight,
+ self.ref_lin.bias,
)


@@ -103,9 +109,15 @@ def __init__(
ignore_index=ignore_index, beta=beta, use_ref_model=True
)

- def forward(self, x, y):
+ def forward(self, x, ref_x, y):
return self.dpo_loss(
- self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias
+ self.lin.weight,
+ x,
+ y,
+ self.lin.bias,
+ ref_x,
+ self.ref_lin.weight,
+ self.ref_lin.bias,
)


@@ -170,6 +182,10 @@ def test_correctness(
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

+ ref_input = (
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
+ )

target = torch.randint(
0,
V,
@@ -185,8 +201,8 @@ def test_correctness(
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

- loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, target)
- loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, target)
+ loss1, aggregated_aux_outputs1 = torch_lm_head_dpo(input1, ref_input, target)
+ loss2, aggregated_aux_outputs2 = liger_lm_head_dpo(input2, ref_input, target)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

@@ -242,6 +258,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

+ ref_input = (
+ torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar
+ )

target = torch.randint(
0,
V,
@@ -270,10 +290,10 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None

loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
- input1, weight1, target, bias1, ref_weight1, ref_bias1
+ input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1
)
loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
- input2, weight2, target, bias2, ref_weight2, ref_bias2
+ input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2
)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
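As a quick reference, a sketch of the updated functional call order mirrored from the test above; the shapes here are illustrative assumptions, and the arguments left unspecified (ignore_index, beta, and so on) fall back to the Function's defaults.

import torch
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

B, T, H, V = 2, 4, 8, 16            # assumed toy sizes; B must stay even
_input = torch.randn(B, T, H, requires_grad=True)
weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (B, T))
bias = torch.randn(V, requires_grad=True)
ref_input = torch.randn(B, T, H)    # reference activations need no grad
ref_weight = torch.randn(V, H)
ref_bias = torch.randn(V)

# ref_input now sits between bias and ref_weight in the positional order.
loss, aux_outputs = LigerFusedLinearDPOFunction.apply(
    _input, weight, target, bias, ref_input, ref_weight, ref_bias
)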
3 changes: 2 additions & 1 deletion test/utils.py
@@ -478,6 +478,7 @@ def get_batch_loss_metrics(
_input: torch.FloatTensor,
target: torch.LongTensor,
bias: torch.FloatTensor = None,
+ ref_input: torch.FloatTensor = None,
ref_weight: torch.FloatTensor = None,
ref_bias: torch.FloatTensor = None,
average_log_prob: bool = True,
@@ -498,7 +499,7 @@ def get_batch_loss_metrics(
loss_kwargs = {}
if self.use_ref_model:
ref_chosen_logps, ref_rejected_logps = self.get_ref_logps(
- _input, ref_weight, target, ref_bias, average_log_prob
+ ref_input, ref_weight, target, ref_bias, average_log_prob
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps