Add ORPO Trainer + support HF metrics directly from chunked loss functions + fixes to avoid torch compile recompilations #429

Merged: 12 commits, Dec 6, 2024
26 changes: 26 additions & 0 deletions examples/alignment/accelerate_config.yaml
@@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
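
This config is intended for the Accelerate launcher. Assuming the example script added below lives at examples/alignment/run_orpo.py, a typical invocation would be: accelerate launch --config_file examples/alignment/accelerate_config.yaml examples/alignment/run_orpo.py, which starts the 8 FSDP-sharded processes configured above on a single machine.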
42 changes: 42 additions & 0 deletions examples/alignment/run_orpo.py
@@ -0,0 +1,42 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer # noqa: F401

from liger_kernel.transformers import LigerORPOTrainer # noqa: F401

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct",
torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct",
max_length=512,
padding="max_length",
)
tokenizer.pad_token = tokenizer.eos_token

train_dataset = load_dataset("trl-lib/tldr-preference", split="train")

# train_dataset = train_dataset.map(
# lambda example: {
# "prompt": example["prompt"],
# "chosen": example["chosen"][0]["content"],
# "rejected": example["rejected"][0]["content"],
# }
# )
training_args = ORPOConfig(
output_dir="Llama3.2_1B_Instruct",
beta=0.1,
max_length=128,
per_device_train_batch_size=32,
max_steps=100,
save_strategy="no",
)

trainer = LigerORPOTrainer(
model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
)

trainer.train()
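
Note: the commented-out map above appears to be left as a template for conversational preference datasets, where chosen and rejected are lists of message dicts and the assistant content must be extracted; the trl-lib/tldr-preference split is consumed here without that conversion.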
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ transformers = [

dev = [
"transformers>=4.44.2",
"trl>=0.11.0",
"matplotlib>=3.7.2",
"flake8>=4.0.1.1",
"black>=24.4.2",
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,7 @@
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
"""
Compute odds-ratio loss.
Args:
@@ -18,7 +18,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
beta (float): Weight for the odds ratio loss.
"""
logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).mean()
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
return loss

@staticmethod
@@ -55,7 +55,7 @@ def forward(
)

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
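A standalone sketch of the normalization change above (illustrative helper, not part of the PR): each chunk now contributes its summed log-sigmoid term divided by the number of preference pairs in the full batch, so summing per-chunk losses reproduces the full-batch mean no matter how the batch is chunked.

import torch
import torch.nn.functional as F

# Hypothetical helper mirroring the new CPO normalization; not the PR's code.
def chunked_cpo_preference_loss(chosen_logps, rejected_logps, full_batch_size, beta=0.1):
    # Divide by the number of pairs in the *full* batch (full_target.shape[0] // 2
    # in the PR) so per-chunk losses can simply be summed across chunks.
    logits = beta * (chosen_logps - rejected_logps)
    return F.logsigmoid(logits).sum() / (full_batch_size // 2)

# Splitting 8 pairs into two chunks matches a single full-batch computation.
chosen, rejected = torch.randn(8), torch.randn(8)
full = chunked_cpo_preference_loss(chosen, rejected, full_batch_size=16)
chunked = sum(
    chunked_cpo_preference_loss(c, r, full_batch_size=16)
    for c, r in zip(chosen.chunk(2), rejected.chunk(2))
)
torch.testing.assert_close(full, chunked)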
7 changes: 4 additions & 3 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -12,6 +12,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
def preference_loss_fn(
chosen_logps,
rejected_logps,
full_target,
ref_chosen_logps=None,
ref_rejected_logps=None,
beta=0.1,
@@ -34,8 +35,8 @@ def preference_loss_fn(
rejected_logratios = rejected_logps - ref_rejected_logps

logits_diff = beta * (chosen_logratios - rejected_logratios)
losses = -F.logsigmoid(logits_diff)
return losses.sum()
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
return loss

@staticmethod
def forward(
@@ -73,7 +74,7 @@ def forward(
)

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, *grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
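The DPO variant follows the same pattern with policy-versus-reference log-ratios; a minimal sketch under the same illustrative naming:

import torch.nn.functional as F

# Hypothetical helper mirroring the reworked DPO term; not the PR's code.
def chunked_dpo_preference_loss(chosen_logps, rejected_logps,
                                ref_chosen_logps, ref_rejected_logps,
                                full_batch_size, beta=0.1):
    # Log-ratios of the policy against the frozen reference model.
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    logits_diff = beta * (chosen_logratios - rejected_logratios)
    # Normalized by the number of pairs in the full batch, as in the CPO case,
    # so per-chunk losses accumulate to the full-batch average.
    return -F.logsigmoid(logits_diff).sum() / (full_batch_size // 2)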
190 changes: 146 additions & 44 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -52,7 +52,17 @@ def chunk_forward(

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]
return chosen_logps, rejected_logps, chosen_nll_loss

chosen_logits = logits_chunk[:len_chosen_chunk]
rejected_logits = logits_chunk[len_chosen_chunk:]

return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
)

@staticmethod
def forward(
@@ -103,6 +113,12 @@ def forward(
grad_rejected_inputs = []
grad_bias = torch.zeros_like(bias) if bias is not None else None
loss_acc = torch.zeros((), device=_input.device)
policy_chosen_logps = []
policy_rejected_logps = []
policy_chosen_logits_mean = torch.zeros((), device=_input.device)
policy_rejected_logits_mean = torch.zeros((), device=_input.device)
policy_nll_loss = torch.zeros((), device=_input.device)
aggregated_aux_outputs = [] # aggregated aux outputs from all chunks

loss_func_to_call = partial(
LigerFusedLinearPreferenceBase._compute_loss,
@@ -118,32 +134,72 @@ def forward(
**loss_kwargs,
)

def accumulate_helper(input_chunk, target_chunk):
if bias is not None:
return torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 3), has_aux=True
)(input_chunk, weight, target_chunk, bias)
else:
return torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1), has_aux=True
)(input_chunk, weight, target_chunk)

def accumulate_chunk(input_chunk, target_chunk):
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
(chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 3), has_aux=True
)(
input_chunk, weight, target_chunk, bias
)
grad_bias.add_(chunk_grad_bias)
(
chunk_chosen_logps,
chunk_rejected_logps,
chunk_chosen_logits_mean,
chunk_rejected_logits_mean,
chunk_nll_loss,
*aux_outputs,
),
) = accumulate_helper(input_chunk, target_chunk)
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
else:
(chunk_grad_input, chunk_grad_weight), (
chunk_loss,
(chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1), has_aux=True
)(
input_chunk, weight, target_chunk
)
(
chunk_chosen_logps,
chunk_rejected_logps,
chunk_chosen_logits_mean,
chunk_rejected_logits_mean,
chunk_nll_loss,
*aux_outputs,
),
) = accumulate_helper(input_chunk, target_chunk)

grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
policy_chosen_logps.append(chunk_chosen_logps)
policy_rejected_logps.append(chunk_rejected_logps)
policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
policy_nll_loss.add_(chunk_nll_loss)

# Initialize storage for aux_outputs
if len(aggregated_aux_outputs) == 0:
for aux in aux_outputs:
if aux.ndim == 0:
aggregated_aux_outputs.append(
torch.zeros((), device=aux.device)
)
else:
aggregated_aux_outputs.append([])

# Process each aux_output
for i, aux in enumerate(aux_outputs):
if aux.ndim == 0:
aggregated_aux_outputs[i].add_(aux)
else:
aggregated_aux_outputs[i].append(aux)

return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)
accumulate_helper = torch.compile(accumulate_helper)

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -168,28 +224,50 @@ def accumulate_chunk(input_chunk, target_chunk):
[chosen_target_chunk, rejected_target_chunk], dim=0
)

# mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
torch._dynamo.mark_dynamic(input_chunk, 1)
torch._dynamo.mark_dynamic(target_chunk, 1)
torch._dynamo.mark_dynamic(target, 1)

# accumulate loss, gradients, and metrics
grad_input = accumulate_chunk(input_chunk, target_chunk)

grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
policy_chosen_logps = torch.cat(policy_chosen_logps, dim=0)
policy_rejected_logps = torch.cat(policy_rejected_logps, dim=0)

# Aggregate aux outputs lists into tensors
for i, aux in enumerate(aggregated_aux_outputs):
if isinstance(aux, list):
aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

ctx.save_for_backward(
torch.cat(grad_inputs, dim=0),
grad_weight,
grad_bias,
)
return loss_acc
return_vars = (
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits_mean,
policy_rejected_logits_mean,
policy_nll_loss,
)
return loss_acc, (*return_vars, *aggregated_aux_outputs)

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, *grad_output):
grad_input, grad_weight, grad_bias = ctx.saved_tensors
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
grad_input = grad_input * grad_output
grad_weight = grad_weight * grad_output
grad_bias = grad_bias * grad_output if grad_bias is not None else None
if torch.ne(
grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
):
grad_input = grad_input * grad_output[0][0]
grad_weight = grad_weight * grad_output[0][0]
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias, None, None, None

@@ -228,40 +306,64 @@ def _compute_loss(
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Additional arguments for the loss function.
"""
chosen_logps, rejected_logps, chosen_nll_loss = (
LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
weight,
target_chunk,
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
)
(
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
weight,
target_chunk,
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)
chosen_logits_mean = chosen_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
)

if use_ref_model:
with torch.no_grad():
ref_chosen_logps, ref_rejected_logps, _ = (
LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False,
)
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
) = LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

alignment_loss = preference_loss_fn(
chosen_logps, rejected_logps, beta=beta, **loss_kwargs
preference_loss_outputs = preference_loss_fn(
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
)
alignment_loss = alignment_loss / (full_target.shape[0] // 2)
if isinstance(preference_loss_outputs, tuple):
preference_loss, *aux_outputs = preference_loss_outputs
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - alignment_loss
return loss, (alignment_loss, chosen_logps, rejected_logps)
loss = alpha * chosen_nll_loss - preference_loss
return_vars = (
chosen_logps,
rejected_logps,
chosen_logits_mean,
rejected_logits_mean,
chosen_nll_loss,
)
return loss, (*return_vars, *aux_outputs)
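
Tying the pieces together: _compute_loss now unpacks a tuple returned by preference_loss_fn, and forward aggregates everything after the first element (scalar aux outputs are accumulated in place, per-example tensors are concatenated across chunks), so a chunked loss can surface extra HF-style metrics directly. A hypothetical loss function, not one of the PR's implementations, might look like this:

import torch.nn.functional as F

# Hypothetical example of returning aux metrics from a chunked loss; the
# specific rewards/margin below are illustrative, not the PR's metrics.
def preference_loss_fn_with_metrics(chosen_logps, rejected_logps, full_target, beta=0.1):
    logits = beta * (chosen_logps - rejected_logps)
    loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
    # Aux outputs: 0-dim tensors are add_-accumulated across chunks by the base
    # class, per-example tensors are collected and torch.cat-ed after the loop.
    chosen_rewards = beta * chosen_logps.detach()
    rejected_rewards = beta * rejected_logps.detach()
    reward_margin = (chosen_rewards - rejected_rewards).mean()
    return loss, chosen_rewards, rejected_rewards, reward_margin

These aux outputs come back from forward appended to the tuple of policy log-probs, mean logits, and NLL loss, which is presumably what lets a trainer such as LigerORPOTrainer report them as HF training metrics.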