
fix: unpackaging error in Custom Mixture of Experts model when aux_loss_enabled is set to True. #2039

Merged: 4 commits into huggingface:main on Sep 9, 2024

Conversation

@Jonathanjordan21 (Contributor) commented Sep 9, 2024

What does this PR do?

This PR fixes #2038.

Fixes an unpacking error caused by the additional aux_loss value returned by the concatenated_forward function when aux_loss_enabled=True.
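
(For context, a minimal illustrative sketch of the failure mode; fake_concatenated_forward below is a hypothetical stand-in, not the actual TRL code.)

def fake_concatenated_forward(aux_loss_enabled):
    # Five regular outputs, plus an extra aux_loss element when aux_loss_enabled=True
    # (e.g. an MoE model configured with output_router_logits=True).
    outputs = ("chosen_logps", "rejected_logps", "chosen_logits", "rejected_logits", "nll_loss")
    return outputs + ("aux_loss",) if aux_loss_enabled else outputs

try:
    # A call site that unpacks a fixed number of names breaks once the extra element appears.
    a, b, c, d, e = fake_concatenated_forward(aux_loss_enabled=True)
except ValueError as err:
    print(err)  # too many values to unpack (expected 5)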

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif (Collaborator) commented Sep 9, 2024

thanks @Jonathanjordan21

perhaps it's more elegant to do:

reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.ref_model, padded_batch)[:2]

what do you think?
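
(Standalone illustration, plain Python rather than TRL code: slicing the returned tuple keeps the unpacking valid whether or not the optional trailing aux_loss is present.)

forward_output = ("chosen_logps", "rejected_logps", "chosen_logits", "rejected_logits", "nll_loss", "aux_loss")
reference_chosen_logps, reference_rejected_logps = forward_output[:2]  # ignores everything after the first two elements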

@Jonathanjordan21 (Contributor, Author)

@kashif seems good. I actually just followed the earlier code that calculates the policy losses in the get_batch_loss_metrics function:

forward_output = self.concatenated_forward(model, batch)
(
    policy_chosen_logps,
    policy_rejected_logps,
    policy_chosen_logits,
    policy_rejected_logits,
    policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
    aux_loss = forward_output[5]

@kashif (Collaborator) commented Sep 9, 2024

Yeah... I should have just done the above, but happy if you do it!

@kashif (Collaborator) commented Sep 9, 2024

You might need to run pre-commit run --all-files in the root of the TRL folder to fix any formatting issues.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif merged commit 72f19c3 into huggingface:main Sep 9, 2024
9 checks passed

Successfully merging this pull request may close these issues.

DPOTrainer failed on training Custom Mixture of Experts model with config output_router_logits=True