[Transformer] fix ORPO loss for MOE models (#479)
## Summary
Adds the missing MoE auxiliary loss to the ORPO loss when `aux_loss_enabled` is set in the trainer.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
kashif authored Dec 16, 2024
1 parent 0bb6c72 commit 21baccc
Showing 1 changed file with 4 additions and 0 deletions.
**src/liger_kernel/transformers/trainer/orpo_trainer.py**

```diff
@@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
             outputs.last_hidden_state,
             concatenated_batch["concatenated_labels"],
         )
+        # if aux_loss_enabled, add the aux_loss to the orpo_loss
+        if self.aux_loss_enabled:
+            orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
         return orpo_loss, aux_outputs
 
     def get_batch_loss_metrics(
```
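For context, MoE models in Hugging Face Transformers (e.g., Mixtral) attach a router load-balancing loss to their output object as `outputs.aux_loss` when router logits are enabled, and the trainer is expected to fold it into the task loss. Below is a minimal sketch of the pattern this commit applies; `add_moe_aux_loss` is a hypothetical helper, not the Liger/TRL API, while `aux_loss_enabled`, `aux_loss_coef`, and `outputs.aux_loss` mirror the names in the diff above:

```python
import torch


def add_moe_aux_loss(
    task_loss: torch.Tensor,
    outputs,
    aux_loss_enabled: bool,
    aux_loss_coef: float,
) -> torch.Tensor:
    """Return the task loss plus the scaled MoE router aux loss, if present.

    Hypothetical helper illustrating the pattern; not the Liger/TRL API.
    """
    if aux_loss_enabled and getattr(outputs, "aux_loss", None) is not None:
        # MoE models (e.g., Mixtral) expose a load-balancing `aux_loss` on
        # their outputs when `output_router_logits=True` in the model config.
        task_loss = task_loss + aux_loss_coef * outputs.aux_loss
    return task_loss
```

Without this term, the router's load-balancing signal is silently dropped for MoE models, which is the gap this commit closes.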
