[RFC] Proximal Policy Optimisation #812
Comments
@SalmanMohammadi thanks so much for the high-quality RFC. PPO would be an amazing technique to add to torchtune! Overall the plan looks good. A quick comment on the model itself:
The description here would lend itself very well to the concepts we have in torchtune.
Based on what you described, I'd imagine that you would build a PPO model by adding a custom component builder which keeps most of the architecture the same but replaces the output layer with what you have in mind. Does this generally make sense? Happy to answer more questions on this. I'd need some more details on the implementation since there's a lot going on here, but I think these would be best communicated in the form of a prototype that does what you had in mind. I'm also cc-ing @vmoens, who's the RL expert in PyTorch, for his thoughts and feedback!
Thanks so much for your feedback @kartikayk. I think it makes sense to start with the reward model implementation. There's a pre-trained reward model for Mistral-7B, so implementing component and model builders for Mistral first could allow for easy testing. There might need to be some small modifications to the checkpointing logic. In HuggingFace, reward models typically inherit from a sequence-classification-style class. Writing my thought process below, I wonder if it makes sense to add a classifier-style module.
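Roughly, I'm imagining something along these lines (just a sketch; the module name and signature are illustrative, and I'm assuming the decoder returns hidden states rather than vocab logits):

```python
import torch
import torch.nn as nn


class TransformerClassifier(nn.Module):
    """Sketch: wrap a transformer decoder and project its final hidden
    states to num_classes outputs (num_classes=1 for a reward model)."""

    def __init__(self, decoder: nn.Module, embed_dim: int, num_classes: int = 1):
        super().__init__()
        self.decoder = decoder
        self.output = nn.Linear(embed_dim, num_classes, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        hidden = self.decoder(tokens)  # assumed shape: [batch, seq_len, embed_dim]
        return self.output(hidden)     # [batch, seq_len, num_classes]
```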
Then, a corresponding component and model would be:
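As a rough illustration, assuming a `mistral` component builder along the lines of the existing builders and the `TransformerClassifier` sketch above (the hyperparameters shown are placeholders, not the real Mistral-7B config):

```python
def mistral_classifier(num_classes: int, *, embed_dim: int, **decoder_kwargs) -> TransformerClassifier:
    """Component builder: a Mistral decoder with a classification head on top."""
    decoder = mistral(embed_dim=embed_dim, **decoder_kwargs)
    return TransformerClassifier(decoder, embed_dim=embed_dim, num_classes=num_classes)


def mistral_7b_reward_model() -> TransformerClassifier:
    """Model builder: a single-output classifier used as the reward model."""
    return mistral_classifier(
        num_classes=1,
        embed_dim=4096,
        vocab_size=32_000,
        # ...plus the remaining Mistral-7B decoder hyperparameters...
    )
```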
Thank you again for your help and feedback. It's super interesting and fun to be contributing.

Sidenote: It probably wouldn't be much more effort to add support for training reward models once we implement reward models in torchtune directly. We could probably use the preference dataset that was implemented for DPO.
This is awesome, @SalmanMohammadi! Seems like you have a lot of the components figured out! A few comments from my side:
This is great! Currently we have logic in the checkpointer which does the state_dict conversion. My first thought would be that you can just create a branch here for reward models.
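Just to sketch the general shape of what I mean (the key names and the `convert_base_key` helper below are illustrative, not our actual checkpointer code):

```python
def convert_reward_model_state_dict(hf_state_dict: dict) -> dict:
    """Illustrative only: branch the per-key conversion for reward-model checkpoints."""
    converted = {}
    for key, value in hf_state_dict.items():
        if key.startswith("score."):
            # Map the HF-style classification head onto the classifier's output projection.
            converted[key.replace("score.", "output.")] = value
        else:
            # Reuse the existing per-key mapping for the base transformer weights.
            converted[convert_base_key(key)] = value
    return converted
```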
This is perfect! I'd do pretty much exactly this. There might be some small nuances which we catch once you have code, but this looks great to me. You alluded to this, but one thing which would be great to do is to verify correctness by running a random input through both your implementation and a reference implementation and comparing the output tensors. This should give us confidence in the numerical equivalence and will help other folks use the module with high confidence. Let me know if this makes sense.
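For example, a minimal parity check might look something like this (assuming both models have already been loaded with the same weights; exact output shapes may need aligning):

```python
import torch


@torch.no_grad()
def check_numerical_parity(tune_model, hf_model, vocab_size: int = 32_000, seq_len: int = 16):
    """Run the same random tokens through both models and compare outputs."""
    tokens = torch.randint(0, vocab_size, (1, seq_len))
    tune_out = tune_model(tokens)
    hf_out = hf_model(tokens).logits  # HF models typically return an output object with .logits
    torch.testing.assert_close(tune_out, hf_out, rtol=1e-4, atol=1e-4)
```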
100% agreed on this. I'd love to collaborate on adding this if you'd be interested. My initial thought here is that this shouldn't be too complicated to add. What do you think? I also saw you had another question about MLP implementations and why these are copied over :) I think it was a great question. Generally, we've tried to decouple the builders for different models as much as possible. This does lead to copy-pasting some code, but generally makes things easier to handle, maintain, extend, and ultimately deprecate. If you try to squeeze too many things into a single implementation, it ultimately becomes bloated and full of conditionals, which makes any sort of extension or refactor hard. Over time, we may find opportunities to consolidate and merge code, but that's an easier operation than splitting things apart to rein in complexity, since splitting would likely break tons of users. Hope this makes sense. Happy to answer more questions!
I love the support and positivity @kartikayk :) I've put a PR up for a (hopefully) pretty lightweight and non-invasive TransformerClassifier.
I think it should be pretty straightforward to add a recipe for training this classifier on a reward modelling task! I'd be happy to hear your thoughts on anything that's missing. I mentioned in the PR that we could start with a recipe using the classifier and the dataset that was implemented for DPO.
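As a very rough sketch of the core of such a recipe (assuming a preference dataset yielding batches of chosen/rejected rewards; this isn't an existing torchtune recipe):

```python
import torch
import torch.nn.functional as F


def reward_model_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Pairwise (Bradley-Terry-style) ranking loss: push the reward of the
    chosen completion above the reward of the rejected completion."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```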
I ended up answering my own question after reading the codebase. It's great to hear your thoughts. There's always a little SWE inside of me complaining about code duplication :) I think the other advantage of the kind of low-coupling, high-modularity code you mentioned is interpretability: I could easily figure out where the implementation details were for an architecture I was interested in. This is, imo, a seriously underrated feature of a popular open-source ML codebase. It makes a huge difference to users at every level of expertise, and particularly to users coming from a non-SWE background who want to understand how things work on a more technical level.
Next steps
Awesome, love the PR @SalmanMohammadi! I'll review in a bit, but I see that you already have a discussion going!
You hit the nail on the head. This was exactly the intent, and I'm glad it resonates. It's one of the design principles we had a lot of discussion and debate about :)
I've been using runpod for my own development and testing. Let me know if this works for you. Of course, we'd be happy to do some testing on larger models and share all of the learnings and observations with you as well. This is really exciting! Thanks for helping shape this up. I'm looking forward to sharing this with the community :)
@kartikayk the TransformerClassifier PR is pretty much good to go. Would you still like to collaborate on the RLHF process? There are a lot of steps, and I have some design docs I could share on the different components we need. Happy to chat here or on Discord to share some of my draft ideas!
@SalmanMohammadi I'm still very interested in the actual training! We can create a sidebar on Discord to chat about this so other interested folks can follow along as well. WDYT?
Sounds good! Let me know what you're interested in and I can share my thoughts/updates on what I'm working on. Let's chat more on Discord.
Sounds good! Mind sharing your Discord handle? :)
Implementing Proximal Policy Optimisation
I've used some of the PyTorch RFC template here for clarity.
Authors: @SalmanMohammadi
Summary
I'd like to add support for fine-tuning models using the Proximal Policy Optimisation (PPO) reinforcement learning (RL) algorithm. Similar to Direct Preference Optimisation (DPO), PPO is a core component in Reinforcement Learning from Human Feedback (RLHF) for aligning language models.
PPO optimises a language model that acts as a policy: the action space is the model's vocabulary, the observation space is the distribution over possible prompts, and the reward is a scalar value indicating the "preference" for the model's completion of a given prompt (usually produced by a reward model calibrated to human preferences).
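For reference, the heart of PPO is the clipped surrogate objective; here is a minimal PyTorch sketch (illustrative only, with the log-probabilities and advantages assumed to come from a rollout of the policy):

```python
import torch


def ppo_clip_loss(
    logprobs: torch.Tensor,      # log-probs of the taken actions under the current policy
    old_logprobs: torch.Tensor,  # log-probs under the policy that generated the rollout
    advantages: torch.Tensor,    # estimated advantages (e.g. from GAE)
    epsilon: float = 0.2,        # clipping range
) -> torch.Tensor:
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # PPO maximises the minimum of the two terms; return the negated mean as a loss.
    return -torch.min(unclipped, clipped).mean()
```

A full recipe would also include a value-function loss and (for RLHF) a KL penalty against a reference model, but the clipped term above is the core policy update.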
Motivation
This repository helps make a fascinating technology even more accessible. Supporting PPO will help users to understand and explore LLM alignment techniques in native PyTorch, which is already widely adopted and easy to get started with.
Proposed Implementation
recipes/
Prior art
TRL implements a generalised PPO trainer. A policy is defined using a thin wrapper around a pre-trained LLM which adds a value-function head that is optimised during PPO training. A copy of the model being trained is also initialised and frozen to serve as a reference model.
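As a rough illustration of that pattern (the names here are hypothetical, not TRL's actual classes, and the wrapped LM is assumed to return hidden states):

```python
import copy
import torch
import torch.nn as nn


class PolicyWithValueHead(nn.Module):
    """Sketch: wrap a causal LM and add a scalar value head over its hidden states."""

    def __init__(self, lm: nn.Module, embed_dim: int):
        super().__init__()
        self.lm = lm
        self.value_head = nn.Linear(embed_dim, 1)

    def forward(self, tokens: torch.Tensor):
        hidden = self.lm(tokens)                      # assumed shape: [batch, seq_len, embed_dim]
        values = self.value_head(hidden).squeeze(-1)  # per-token value estimates
        return hidden, values


def make_frozen_reference(policy: nn.Module) -> nn.Module:
    """Deep-copy the policy and freeze it for use as the reference model."""
    ref = copy.deepcopy(policy)
    for p in ref.parameters():
        p.requires_grad_(False)
    return ref.eval()
```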
Feedback and thoughts are very much appreciated. I'm hoping to add value here and I'm grateful for any guidance to help me do so.