[Feature] RLHF Reward Model #1315
Conversation
Shouldn't we rebase on #1309 to make sure tests are green?
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Left a couple of comments, but LGTM! Thanks!
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Great work, thanks a lot for this.
A couple of comments regarding efficiency: I don't think we need to tackle this now, but let's make sure we keep track somewhere that there could be room for improvement.
torchrl/modules/models/rlhf.py
Outdated
""" Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token | ||
""" |
""" Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token | |
""" | |
"""Computes the rewards associated with some encoded sequence of tokens. | |
Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token | |
""" |
torchrl/modules/models/rlhf.py
Outdated
hidden_states = outputs[0]
rewards = self.lm_head(hidden_states).squeeze(-1)
end_scores = []
bs = input_ids.shape[0]

for i in range(bs):
    pad_inds = (input_ids[i] == self.PAD_ID).nonzero()
    first_pad_ind = (
        pad_inds[0].item() if len(pad_inds) > 0 else input_ids.shape[1]
    )
    end_scores.append(rewards[i, first_pad_ind - 1])
My intuition is that there is a better (more efficient) way of coding that loop.
Let's move it to a private method so that we can easily refactor this piece of code and compare the results.
For instance, assuming that padding is only on the right:
>>> import torch
>>> z = torch.arange(12).view(3, 4)
>>> z[0, 2:] = 100
>>> z[1, 3:] = 100
>>> z[2, 1:] = 100
>>> mask = z == 100
>>> mask = torch.cat([mask, torch.ones_like(mask[..., :1])], -1) # make sure that there is one True on each row
>>> first_pad = mask[..., :-1] ^ mask[..., 1:]
>>> first_pad = first_pad.nonzero()
>>> first_pad = first_pad[:, -1]
Thanks! I've factored the slow code out into a private method as you suggested. Did I understand right that we should land as is and follow up with speed improvements?
for i in range(bs):
    # Check if there is any padding otherwise take length of sequence
    c_inds = (chosen_ids[i] == pad_token_id).nonzero()
In general, nonzero is expensive and shouldn't be called too often.
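As an illustrative sketch (not code from this PR), the first padding index can be found for a whole batch without nonzero by taking argmax over the padding mask, since recent PyTorch documents argmax as returning the first maximal index; the same trick would apply to the divergence index flagged below. The tensor names and the pad id of 0 are hypothetical:
>>> import torch
>>> input_ids = torch.tensor([[5, 7, 0, 0], [3, 4, 6, 0], [9, 8, 7, 6]])  # toy batch, pad_token_id assumed to be 0
>>> mask = input_ids == 0
>>> mask = torch.cat([mask, mask.new_ones(mask.shape[0], 1)], -1)  # ensure one True per row for unpadded sequences
>>> first_pad_ind = mask.int().argmax(-1)  # index of the first True per row, no nonzero call
>>> first_pad_ind
tensor([2, 3, 4])
>>> rewards = torch.randn(input_ids.shape[0], input_ids.shape[1])
>>> end_scores = rewards[torch.arange(input_ids.shape[0]), first_pad_ind - 1]  # reward at the last non-padding token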
end_ind = max(c_ind, r_ind)

# Retrieve first index where trajectories diverge
divergence_ind = (chosen_ids[i] != rejected_ids[i]).nonzero()[0]
Ditto, nonzero is expensive.
torchrl/modules/models/rlhf.py
Outdated
The loss is computed as loss = -log_sigmoid(chosen_reward - rejected_reward).
This loss is small when the reward model favours the chosen data and large if
the model favours the rejected data.
Note: the loss is computed excluding the common "prefix" subsequence to effectively disregard contribution of the original prompt.
Perhaps an example of a call to this?
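A minimal sketch of such a call, using only the formula quoted in the docstring above; the reward values are made up and the mean reduction is an assumption:
>>> import torch
>>> import torch.nn.functional as F
>>> chosen_reward = torch.tensor([1.2, 0.3, 2.0])    # hypothetical rewards for the chosen responses
>>> rejected_reward = torch.tensor([0.4, 0.9, 1.1])  # hypothetical rewards for the rejected responses
>>> loss = -F.logsigmoid(chosen_reward - rejected_reward).mean()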
# Conflicts: # test/test_rlhf.py
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
This PR builds on top of #1309, adding the GPT2RewardModel class and tests. Changes from that PR are needed so that we can use the dataloaders in the tests.
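For reference, a rough usage sketch of the new class; the import path comes from the diff header above, but the constructor argument and the (input_ids, attention_mask) forward signature are assumptions inferred from the excerpts, not a verified API:
>>> from transformers import GPT2Tokenizer
>>> from torchrl.modules.models.rlhf import GPT2RewardModel  # path taken from the diff header
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
>>> reward_model = GPT2RewardModel("gpt2")  # assumed to wrap a pretrained GPT-2 backbone
>>> batch = tokenizer(["a prompt followed by a response"], padding="max_length", max_length=32, return_tensors="pt")
>>> rewards, end_scores = reward_model(batch["input_ids"], batch["attention_mask"])  # per the docstring: per-timestep rewards and last-non-padding-token scores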