custom reward function support for ppo trainer #2540
base: main
Conversation
trl/trainer/utils.py
Outdated
@@ -1049,14 +1049,20 @@ def first_true_indices(bools: torch.Tensor, dtype=torch.long):
def get_reward(
This is where the primary change is: modifying the `get_reward` function to work with both an `nn.Module` and a `Callable`.
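For readers following along, here is a minimal sketch of the kind of dispatch being described. The signature (taking the processing class) and the callable contract (a list of decoded texts in, one float per text out) are assumptions based on this thread, not the PR's exact diff:

```python
import torch


def get_reward(model, processing_class, query_responses, pad_token_id, context_length):
    """Hedged sketch of the dispatch described above, not the PR's exact code."""
    if isinstance(model, torch.nn.Module):
        # The original nn.Module scoring path would stay as it is today (omitted in this sketch).
        raise NotImplementedError("unchanged reward-model path")
    # Callable path: decode token ids to text and let the custom function assign scores.
    # `pad_token_id` and `context_length` are unused in this branch of the sketch.
    texts = processing_class.batch_decode(query_responses, skip_special_tokens=True)
    scores = torch.tensor(model(texts), dtype=torch.float, device=query_responses.device)
    # Return a three-element tuple so call sites like `_, score, _ = get_reward(...)` keep working.
    return None, scores, None
```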
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Question: does this theoretically work? I'm asking because I haven't read the PPO papers. For example, let's say the custom reward function is based on the count of a specific word, like "good":

def reward_function(texts):
    rewards = [text.count("good") for text in texts]
    return rewards

When the PPO trainer is training, the printed output is just the count of the word "good" in each text, and it looks normal since it's in the same format. But is there more to it, theoretically?
trl/trainer/ppo_trainer.py
Outdated
    _, score, _ = get_reward(
-       self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+       self.reward_model,
+       processing_class,
+       postprocessed_query_response,
+       processing_class.pad_token_id,
+       context_length,
Can we move the `if isinstance(model, torch.nn.Module):` check here? That would allow us not to introduce a breaking change in `get_reward`.
You need to clarify what you mean.
Sorry, it wasn't clear. Something like this instead:

if isinstance(model, torch.nn.Module):
    full_value, _, _ = get_reward(
        unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
    )
else:
    full_value = ...

This way we don't introduce a breaking change in `get_reward`.
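Applied to the reward-scoring call from the earlier diff, that dispatch might look like the following sketch. The decode step, the `skip_special_tokens=True` flag, and the tensor conversion are assumptions about how the callable branch could be filled in; `get_reward` itself stays unchanged, and `torch` plus the trainer-scope variables are assumed to be available:

```python
import torch  # already imported at module level inside the trainer

if isinstance(self.reward_model, torch.nn.Module):
    # Existing behavior: score with the reward model via the unchanged get_reward helper.
    _, score, _ = get_reward(
        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
    )
else:
    # Custom reward function: decode the sampled sequences and score the raw texts.
    texts = processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
    score = torch.tensor(
        self.reward_model(texts), dtype=torch.float, device=postprocessed_query_response.device
    )
```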
I mean I changed `get_reward` to work either way, with both a callable and an `nn.Module`. So you want to add the `if isinstance(model, torch.nn.Module)` check there and keep `get_reward` as it is, without changes?
Can you add a test as well?
I'll take that as a yes. I will add the test and the docs later, and maybe a blog post or something to show how it works, if I don't run out of resources.
Thanks for the contribution! We look forward to having this flexibility added!
@qgallouedec
""" | ||
This function ensures that the custom reward function produces the correct output structure for integration with the trainer script. | ||
""" | ||
texts = processor.batch_decode(query_responses) |
Should we skip special tokens here?
Yes, good point.

Another concern: currently, `postprocessed_query_response` includes both the prompt and the generated response, which are then scored by the reward model (or custom reward function). Should the reward model only score the generated response, or should it score both the prompt and the response together?
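For reference, a one-line sketch of the decode call with special tokens skipped, assuming the processing class exposes `batch_decode` as in the diff above:

```python
# Skip pad/eos and other special tokens so the custom reward function only sees the actual text.
texts = processor.batch_decode(query_responses, skip_special_tokens=True)
```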
> should it score both the prompt and the response together?

Yes, and the reason is quite intuitive. For example, consider a prompt like `2+3=` with the generated response being `6`. If the reward model only has access to the generated response (`6`), how could it determine whether the calculation is correct without knowing the prompt?
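As a toy illustration of that point, a custom reward function for this kind of arithmetic task has to see the full decoded prompt + response to judge correctness. This is a hedged sketch; the exact decoded format (e.g. `"2+3=5"`) is an assumption:

```python
def arithmetic_reward(texts):
    # Each text is assumed to be the decoded prompt + response, e.g. "2+3=5".
    rewards = []
    for text in texts:
        try:
            expression, answer = text.split("=", 1)
            a, b = (int(part) for part in expression.split("+"))
            rewards.append(1.0 if int(answer.strip()) == a + b else 0.0)
        except ValueError:
            # A bare response like "6" (no prompt) or malformed text cannot be judged.
            rewards.append(0.0)
    return rewards
```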
Yes, it looks better imo!
unwrapped_value_model,
query_response,
processing_class.pad_token_id,
context_length,
Suggested change:
-    unwrapped_value_model,
-    query_response,
-    processing_class.pad_token_id,
-    context_length,
+    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
You could also add a test with, e.g., this reward function:

def reward_func(text):
    return float(len(text))
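A hedged sketch of what a unit test for that contract could look like. How the trainer consumes the callable is left out since it depends on the final implementation; only the list-of-texts-in, list-of-floats-out interface proposed in this PR is assumed:

```python
import unittest


class TestCustomRewardFunction(unittest.TestCase):
    def test_reward_func_returns_one_float_per_text(self):
        # Batched variant of the length-based reward suggested above.
        def reward_func(texts):
            return [float(len(text)) for text in texts]

        texts = ["short", "a somewhat longer generated response"]
        rewards = reward_func(texts)
        self.assertEqual(len(rewards), len(texts))
        self.assertTrue(all(isinstance(reward, float) for reward in rewards))
        self.assertEqual(rewards[0], float(len("short")))
```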
Hi, does this same change work for the RLOO trainer? (i.e., could the same change be applied to the RLOO trainer?)
We will apply the same approach to RLOO after conducting some tests on this.
@qgallouedec correct me if I'm wrong, but the code at trl/trainer/ppo_trainer.py, lines 459 to 461 (commit 88514d5) …
I tested the branch locally and it seems to work fine.
wut? 🤔 Could you share the code you're using that works for you? I'm currently considering how the value model would be replaced or used, and what the …
No, I don't think you need to modify anything related to the value function here.
In fact, it's the same with any reward function: the value model is trained to estimate the value of a state (= token), i.e. the expectation of the discounted future rewards. It seems to me that the current implementation is sufficient.
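For reference, the standard definition being paraphrased here: the value of a state $s_t$ under policy $\pi$ is the expected discounted sum of future rewards,

$$V^{\pi}(s_t) = \mathbb{E}_{\pi}\!\left[\sum_{k=0}^{\infty} \gamma^{k}\, r_{t+k+1} \,\middle|\, s_t\right]$$

so the value head is trained against whatever rewards the trainer receives, whether they come from a reward model or a custom reward function.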
@qgallouedec (see trl/trainer/ppo_trainer.py, line 119, commit d9f0568): it breaks at `PolicyAndValueWrapper` if I don't provide a `value_model`. I also can't use the reward function as the `value_model`.
Oh, I get what you mean. @August-murr @qgallouedec I think the value model is typically initialized from the trained reward model, but when we don't have a reward model, what do we specify as the value model?
In the meantime I implemented the same for RLOO and can confirm it (apparently) works.
What @Superskyyy said is a real concern. We still need a value model, and it probably has to be trained just like a reward model? Would using the SFT pre-trained base as the value model work? Has anyone experimented with this?
In theory it should. There are two ways of initializing a value model: either from the policy or from the trained reward model.
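A minimal sketch of those two options. The checkpoint names are hypothetical, and the use of `AutoModelForSequenceClassification` with `num_labels=1` mirrors how TRL's PPO examples typically build a scalar-head value model (an assumption, not something prescribed by this PR):

```python
from transformers import AutoModelForSequenceClassification

# Option 1: initialize the value model from the SFT/policy checkpoint with a fresh scalar head.
value_model = AutoModelForSequenceClassification.from_pretrained("my-org/my-sft-model", num_labels=1)

# Option 2: if a trained reward model exists, start the value model from its weights instead.
# value_model = AutoModelForSequenceClassification.from_pretrained("my-org/my-reward-model", num_labels=1)
```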
What does this PR do?
Fixes #2518
Adding support for a custom reward function for the PPO trainer.
How it works
Write a custom function that takes a list of texts as input, representing a batch of responses, and outputs a list of scores.
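For example, a minimal sketch of a reward function matching that interface; how the callable is passed to the trainer follows this PR's proposal and may change:

```python
def word_count_reward(texts):
    # One float score per decoded query+response text in the batch.
    return [float(text.lower().count("good")) for text in texts]

# Assumed usage per this PR: pass the callable where the reward model normally goes,
# e.g. PPOTrainer(..., reward_model=word_count_reward, ...).
```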
I will add more documentation and explanations later after running several tests to make sure the implementation is functional.
Who can review?
@qgallouedec