
[RFC] Proximal Policy Optimisation #812

Closed
SalmanMohammadi opened this issue Apr 19, 2024 · 9 comments · Fixed by #1005
Labels
rfc Request for comments

Comments

@SalmanMohammadi
Collaborator

Implementing Proximal Policy Optimisation

I've used some of the PyTorch RFC template here for clarity.

Authors:

Summary

I'd like to add support for fine-tuning models using the Proximal Policy Optimisation (PPO) reinforcement learning (RL) algorithm. Similar to Direct Preference Optimisation (DPO), PPO is a core component of Reinforcement Learning from Human Feedback (RLHF) for aligning language models.

PPO optimises a language model acting as a policy: the action space is the model's vocabulary, the observation space is the distribution over all possible prompts, and the reward is a scalar indicating the "preference" for the model's completion to a given prompt (the reward is usually produced by a reward model calibrated to human preferences).
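For reference, the core update PPO performs is the clipped surrogate objective. A minimal sketch in PyTorch (tensor names and shapes are illustrative, not a proposed API):

import torch

def ppo_clip_loss(
    logprobs: torch.Tensor,      # log-probs of sampled tokens under the current policy, [b x s]
    old_logprobs: torch.Tensor,  # log-probs under the policy that generated the rollouts, [b x s]
    advantages: torch.Tensor,    # per-token advantage estimates, [b x s]
    epsilon: float = 0.2,        # clipping range
) -> torch.Tensor:
    # probability ratio between the current policy and the rollout policy
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    # PPO maximises the clipped surrogate, so we minimise its negation
    return -torch.min(unclipped, clipped).mean()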

Motivation

This repository helps make a fascinating technology even more accessible. Supporting PPO will help users to understand and explore LLM alignment techniques in native PyTorch, which is already widely adopted and easy to get started with.

Proposed Implementation

  • The algorithm itself could be implemented as a recipe in recipes/.
  • I don't think datasets need to be in a specific format when using a reward model, so existing dataset functionality can be used.
  • Integration of reward models into the codebase: would this require a native reward model implementation, given that the authors of this repo would like all models to be in native PyTorch? In practice, reward models are (sometimes smaller) copies of the model being fine-tuned, so it could be as simple as inheriting from current model implementations and adapting the final layer to output a scalar reward.

Prior art

TRL implements a generalised PPO trainer. The policy is defined as a thin wrapper around a pre-trained LLM which adds a value-function head optimised during PPO training. A copy of the model being trained is also initialised and frozen to act as a reference model.
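As a rough illustration of that pattern (a sketch only; the decoder is a stand-in nn.Module assumed to return final hidden states, not TRL's actual API):

import copy

import torch
import torch.nn as nn

class PolicyWithValueHead(nn.Module):
    """Wraps a decoder and adds a scalar value head trained during PPO."""
    def __init__(self, decoder: nn.Module, embed_dim: int):
        super().__init__()
        self.decoder = decoder
        self.value_head = nn.Linear(embed_dim, 1)

    def forward(self, tokens: torch.Tensor):
        hidden = self.decoder(tokens)                  # assumed shape [b x s x embed_dim]
        values = self.value_head(hidden).squeeze(-1)   # per-token value estimates, [b x s]
        return hidden, values

def make_frozen_reference(policy_decoder: nn.Module) -> nn.Module:
    """Frozen copy of the initial policy, used for the KL penalty during PPO."""
    ref = copy.deepcopy(policy_decoder)
    for p in ref.parameters():
        p.requires_grad_(False)
    return ref.eval()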

Feedback and thoughts are very much appreciated. I'm hoping to add value here and I'm grateful for any guidance to help me do so.

@kartikayk
Contributor

@SalmanMohammadi thanks so much for the high quality RFC. PPO would be an amazing technique to add to torchtune!

Overall the plan looks good. A quick comment on the model itself:

Integration of reward models into the codebase

The description here lends itself very well to the concepts we have in torchtune.

  • Component Builders. For each model, we have component_builders which stitch together the modules to build the overall architecture. For example, the Llama3 component builders can be found here. This includes the llama3 and lora_llama3 models.
  • Model Builders. Once the arch is ready, we create specific instantiations of the architecture using the right hyperparams. For example, the llama3 component builder is used to create the llama3_8b and llama3_70b model builders here (a toy sketch of this two-level pattern follows below).
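To illustrate the two levels with a toy example (the real torchtune builders wire up full transformer blocks and take many more hyperparameters; this is just the shape of the pattern):

import torch.nn as nn

# component builder (in _component_builders.py): stitches modules into an architecture
def toy_decoder(vocab_size: int, num_layers: int, embed_dim: int) -> nn.Module:
    layers = [nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)]
    return nn.Sequential(nn.Embedding(vocab_size, embed_dim), *layers)

# model builder (in _model_builders.py): a specific instantiation with fixed hyperparams
def toy_decoder_8b() -> nn.Module:
    return toy_decoder(vocab_size=128_256, num_layers=32, embed_dim=4096)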

Based on what you described, I'd imagine that you would build a ppo model by adding a custom component builder which keeps most of the architecture the same but replaces the output layer with what you have in mind. Does this generally make sense? Happy to answer more questions on this.

I'd need some more details on the implementation since there's a lot going on here, but I think these would be best communicated in the form of a prototype that does what you had in mind.

I'm also cc-ing @vmoens who's the RL expert in PyTorch for his thoughts and feedback!

@SalmanMohammadi
Collaborator Author

SalmanMohammadi commented Apr 20, 2024

Thanks so much for your feedback @kartikayk.

I think it makes sense to start with the reward model implementation. There's a pre-trained reward model for Mistral-7B. Implementing component and model builders for Mistral to start could allow for easy testing. There might need to be some small modifications to convert_weights.py to support loading reward models.

In HuggingFace, reward models use the AutoModelForSequenceClassification generic. This is just a sequence model with a linear classification layer (example for Mistral-7B) slapped on top of the final hidden state of the underlying model.
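For reference, loading such a reward model on the HuggingFace side looks roughly like this (the checkpoint name is a placeholder, not a specific recommended model):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# num_labels=1 gives a single scalar "reward" per sequence
checkpoint = "some-org/mistral-7b-reward-model"  # placeholder name
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("prompt and completion to score", return_tensors="pt")
reward = model(**inputs).logits  # shape [1, 1]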

Writing my thought process below, I wonder if it makes sense to add a TransformerClassifier in transformer.py, with a forward that looks something like:

from typing import Optional

import torch.nn as nn
from torch import Tensor
from torchtune.modules import TransformerDecoder

class TransformerClassifier(nn.Module):
    def __init__(self, transformer_decoder: TransformerDecoder, embed_dim: int, n_classes: int):
        super().__init__()
        self.transformer_decoder = transformer_decoder
        self.score = nn.Linear(embed_dim, n_classes)

    def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            tokens (Tensor): input token IDs with shape [b x s]
            input_pos (Optional[Tensor]): optional position IDs, as in TransformerDecoder
        Returns:
            Tensor: preferences/rewards with shape [b x s x n_classes]
        """
        # assumes the decoder exposes final hidden states with shape [b x s x embed_dim]
        transformer_output = self.transformer_decoder(tokens, input_pos=input_pos)
        score = self.score(transformer_output)
        # return logits / apply activation / pool to the final token, etc.
        return score

Then, corresponding component and model builders would be:

# in _component_builders.py
def mistral_classifier(embed_dim, n_classes, **mistral_args) -> TransformerClassifier:
    transformer_decoder = mistral(**mistral_args)
    return TransformerClassifier(transformer_decoder, embed_dim, n_classes)

# in _model_builders.py
def mistral_7b_classifier() -> TransformerClassifier:
    ...

Thank you again for your help and feedback. It's super interesting and fun to be contributing.

Sidenote

It probably wouldn't be much more effort to add support for training reward models once we implement reward models in Torchtune directly. We could probably use the PreferenceDataset that was implemented for DPO. I suppose it's technically a form of fine-tuning, so might be in scope of this library. It'd be really nice to allow users to go through the full RLHF process in native torch.

@kartikayk
Contributor

This is awesome, @SalmanMohammadi! Seems like you have a lot of the components figured out!

A few comments from my side:

There might need to be some small modifications to convert_weights.py to support loading reward models

This is great! Currently we have logic in the checkpointer which does the state_dict conversion. My first thought would be that you can just create a branch here for reward models by using the model_type field. I'm not sure how general these might be so maybe we can start with something like MISTRAL_REWARD_MODEL and extend when we add more models? Let me know if that makes sense.

Writing my thought process below

This is perfect! I'd do pretty much exactly this. There might be some small nuances we catch once you have code, but this looks great to me. You alluded to this, but one thing which would be great to do is to verify correctness by running a random input through your implementation and a reference implementation and comparing the output tensors. This should give us confidence in the numerical equivalence and will help other folks use the module with high confidence. Let me know if this makes sense.
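A sketch of that parity check (the two models here are placeholders for the torchtune module and a reference implementation; both are assumed to map token IDs to tensors of the same shape):

import torch

def check_numerical_parity(tune_model: torch.nn.Module,
                           ref_model: torch.nn.Module,
                           vocab_size: int,
                           batch_size: int = 2,
                           seq_len: int = 16) -> None:
    # feed identical random tokens through both models and compare outputs
    torch.manual_seed(0)
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    with torch.no_grad():
        torch.testing.assert_close(
            tune_model(tokens), ref_model(tokens), rtol=1e-5, atol=1e-5
        )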

It'd be really nice to allow users to go through the full RLHF process in native torch.

100% agreed on this. I'd love to collaborate on adding this if you'd be interested. My initial thought here is that this shouldn't be too complicated to add. What do you think?

I also saw you had another question about MLP implementations and why these are copied over :) I think it was a great question. Generally, we've tried to decouple the builders for different models as much as possible. This does lead to copy-pasting some code, but generally makes things easier to handle, maintain, extend and ultimately deprecate. If you try to squeeze too many things into a single implementation, it ultimately becomes bloated and full of conditionals, which makes any sort of extension or refactor hard. Over time we may find opportunities to consolidate and merge code, but that's an easier operation than splitting things apart later to keep complexity down, which would likely break tons of users. Hope this makes sense. Happy to answer more questions!

@SalmanMohammadi
Collaborator Author

I love the support and positivity @kartikayk :)

I've put a PR up for a (hopefully) pretty lightweight and non-invasive TransformerClassifier implementation. I could use some guidance on numerical testing. I'd be happy to also add correctness tests for base mistral, and then the mistral classifier.

I'd love to collaborate on adding this if you'd be interested

I think it should be pretty straightforward to add a recipe for training this classifier on a reward modelling task! I'd be happy to hear your thoughts on anything that's missing. I mentioned in the PR that we could start with a recipe using the classifier and the dataset that was implemented for DPO.
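For context, reward-model training on preference pairs typically uses a pairwise (Bradley-Terry style) loss over the classifier's scalar outputs; a minimal sketch:

import torch
import torch.nn.functional as F

def pairwise_reward_loss(chosen_rewards: torch.Tensor,
                         rejected_rewards: torch.Tensor) -> torch.Tensor:
    # push the reward of the chosen completion above that of the rejected one
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()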

Generally, we've tried to decouple the builders for different models as much as possible.

I ended up answering my own question after reading the codebase. It's great to hear your thoughts. There's always a little SWE inside of me complaining about code duplication : ) I think the other advantage of the kind of low-coupling, high-modularity code you mentioned is interpretability. I could easily figure out where the implementation details were for an architecture I was interested in. This is, imo, a seriously underrated feature of a popular open-source ML codebase. It makes a huge difference for users at every level of expertise, and particularly for users from a non-SWE background who want to understand how things work at a more technical level.

Next steps
It'd be good to talk more about implementing reward model training. Once we've worked through the TransformerClassifier testing and the PR looks good, I'll hopefully have most of the components I need to implement PPO too. I don't currently have resources to test or train larger models - if you have suggestions for cheap cloud compute/compute for open-source development I'd appreciate any pointers!
On a more general note, I'd also be happy to help write tutorials/documentation on the things we're working on.

@kartikayk
Contributor

Awesome, love the PR @SalmanMohammadi! I'll review in a bit, but see that you already have a discussion going!

I could easily figure out where the implementation details were for an architecture I was interested in.

You hit the nail on the head. This was exactly the intent, and I'm glad it resonates. It's one of the design principles we had a lot of discussion and debate about :)

I don't currently have resources to test or train larger models - if you have suggestions for cheap cloud compute/compute for open-source development I'd appreciate any pointer

I've been using runpod for my own development and testing. Let me know if this works for you? Of course, we'd be happy to do some testing on larger models and share all of the learnings and observations with you as well.

This is really exciting! Thanks for helping shape this up. I'm looking forward to sharing this with the community :)

@SalmanMohammadi
Collaborator Author

SalmanMohammadi commented Apr 30, 2024

@kartikayk the TransformerClassifier PR is pretty much good to go. Would you still like to collaborate on the RLHF process? There's a lot of steps and I have some design docs I could share on the different components we need. Happy to chat here or on Discord to share some of my draft ideas!

@kartikayk
Contributor

@SalmanMohammadi I'm still very interested in the actual training! We can create a sidebar on Discord to chat about this so other interested folks can follow along as well. WDYT?

@SalmanMohammadi
Collaborator Author

SalmanMohammadi commented Apr 30, 2024

Sounds good! Let me know what you're interested in and I can share my thoughts/updates on what I'm working on. Let's chat more on Discord.

@kartikayk
Contributor

Sounds good! Mind sharing your discord handle? :)
