[Feature] RLHF Reward Model #1315

Closed
wants to merge 10 commits

Conversation

tcbegley
Contributor

@tcbegley tcbegley commented Jun 26, 2023

This PR builds on top of #1309, adding the GPT2RewardModel class and tests. Changes from that PR are needed so that we can use the dataloaders in the tests.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jun 26, 2023
@tcbegley tcbegley changed the base branch from main to rlhf_data June 26, 2023 10:35
Contributor

@apbard apbard left a comment

Shouldn't we rebase on #1309 to make sure tests are green?

torchrl/modules/models/rlhf.py (outdated, resolved)
torchrl/modules/models/rlhf.py (resolved)
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Contributor

@apbard apbard left a comment

Left a couple of comments, but LGTM! Thanks!

torchrl/modules/models/rlhf.py (resolved)
torchrl/modules/models/rlhf.py (resolved)
Co-authored-by: Alessandro Pietro Bardelli <apbard@users.noreply.github.com>
Contributor

@vmoens vmoens left a comment

Great work, thanks a lot for this.
A couple of comments regarding efficiency: I don't think we need to tackle this now, but let's make sure we keep track somewhere that there could be room for improvement.

torchrl/modules/models/rlhf.py (outdated, resolved)
Comment on lines 32 to 33
""" Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token
"""
Contributor

Suggested change
- """ Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token
- """
+ """Computes the rewards associated with some encoded sequence of tokens.
+ Returns a tuple (rewards, end_scores) where `rewards` contains all rewards computed at each timestep, `end_scores` contains the reward computed at the last-non-padding token
+ """

torchrl/modules/models/rlhf.py (outdated, resolved)
torchrl/modules/models/rlhf.py (resolved)
Comment on lines 35 to 45
hidden_states = outputs[0]
rewards = self.lm_head(hidden_states).squeeze(-1)
end_scores = []
bs = input_ids.shape[0]

for i in range(bs):
    pad_inds = (input_ids[i] == self.PAD_ID).nonzero()
    first_pad_ind = (
        pad_inds[0].item() if len(pad_inds) > 0 else input_ids.shape[1]
    )
    end_scores.append(rewards[i, first_pad_ind - 1])
Contributor

My intuition is that there is a better (more efficient) way of coding that loop.
Let's move it to a private method so that we can easily refactor this piece of code and compare the results.
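
A hedged sketch of what that private method could look like, lifted directly from the loop above (the method name is illustrative, and torch is assumed to be imported at module level):

def _compute_end_scores(self, rewards, input_ids):
    # loop-based implementation isolated here so that a vectorised
    # version can later be swapped in and benchmarked against it
    end_scores = []
    bs = input_ids.shape[0]
    for i in range(bs):
        pad_inds = (input_ids[i] == self.PAD_ID).nonzero()
        first_pad_ind = (
            pad_inds[0].item() if len(pad_inds) > 0 else input_ids.shape[1]
        )
        end_scores.append(rewards[i, first_pad_ind - 1])
    return torch.stack(end_scores)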

Contributor

For instance, assuming that padding is only on the right:

>>> import torch
>>> z = torch.arange(12).view(3, 4)
>>> z[0, 2:] = 100
>>> z[1, 3:] = 100
>>> z[2, 1:] = 100
>>> mask = z == 100
>>> mask = torch.cat([mask, torch.ones_like(mask[..., :1])], -1) # make sure that there is one True on each row
>>> first_pad = mask[..., :-1] ^ mask[..., 1:]
>>> first_pad = first_pad.nonzero()
>>> first_pad = first_pad[:, -1]
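>>> # hedged continuation of this sketch: the resulting indices could gather the
>>> # end scores in one shot; `rewards` below is a random stand-in for the
>>> # model's per-token rewards
>>> rewards = torch.randn(3, 4)
>>> end_scores = rewards[torch.arange(3), first_pad]
>>> end_scores.shape
torch.Size([3])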

Contributor Author

Thanks! I've factored the slow code out into a private method as you suggested. Did I understand correctly that we should land this as is and follow up with speed improvements?


for i in range(bs):
    # Check if there is any padding otherwise take length of sequence
    c_inds = (chosen_ids[i] == pad_token_id).nonzero()
Contributor

In general, nonzero is expensive and shouldn't be called too often.
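
One possible loop-free alternative, sketched under the assumption that padding only ever appears on the right (so the number of non-pad tokens equals the index of the first pad token):

mask = chosen_ids == pad_token_id   # (bs, seq_len) boolean mask of padding positions
first_pad_inds = (~mask).sum(-1)    # index of the first pad token per row, or seq_len if there is none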

end_ind = max(c_ind, r_ind)

# Retrieve first index where trajectories diverge
divergence_ind = (chosen_ids[i] != rejected_ids[i]).nonzero()[0]
Contributor

Ditto, nonzero is expensive.
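
A hedged, batched alternative for the divergence index (assuming chosen_ids and rejected_ids are 2D tensors of the same shape; rows that never diverge get seq_len):

diff = chosen_ids != rejected_ids                                  # (bs, seq_len)
positions = torch.arange(diff.shape[-1], device=diff.device).expand_as(diff)
# keep the position where the sequences differ, otherwise use seq_len, then take the row-wise minimum
divergence_inds = torch.where(diff, positions, torch.full_like(positions, diff.shape[-1])).min(-1).values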

torchrl/modules/models/rlhf.py (resolved)
The loss is computed as loss = -log_sigmoid(chosen_reward - rejected_reward).
This loss is small when the reward model favours the chosen data and large if
the model favours the rejected data.
Note: the loss is computed excluding the common "prefix" subsequence to effectively disregard the contribution of the original prompt.
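
A small, hedged numeric illustration of the loss described above (values are arbitrary):

import torch
import torch.nn.functional as F

chosen_reward = torch.tensor([2.0])
rejected_reward = torch.tensor([-1.0])
loss = -F.logsigmoid(chosen_reward - rejected_reward)
# the reward model favours the chosen response, so the loss is small (about 0.049)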
Contributor

Perhaps an example of a call to this?

torchrl/modules/models/rlhf.py (resolved)
vmoens and others added 4 commits June 27, 2023 08:34
@apbard apbard mentioned this pull request Jun 27, 2023
@vmoens vmoens deleted the branch pytorch:rlhf_data June 27, 2023 17:16
@vmoens vmoens closed this Jun 27, 2023
@tcbegley tcbegley changed the title from [Feature, NOMERGE] RLHF Reward Model to [Feature] RLHF Reward Model Jun 28, 2023