Move RLHF out of modules (#1591)
SalmanMohammadi authored Sep 16, 2024
1 parent 6820089 commit 000bb70
Showing 25 changed files with 57 additions and 56 deletions.
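In short, this commit promotes `torchtune.modules.rlhf` to a top-level `torchtune.rlhf` package; every import and cross-reference updated below follows that pattern. A minimal before/after sketch of the change for downstream code, using only paths that appear in this diff:

```python
# Before this commit, RLHF utilities lived under torchtune.modules:
#   from torchtune.modules import rlhf
#   from torchtune.modules.rlhf.loss import DPOLoss, PPOLoss

# After this commit, RLHF is a top-level package:
from torchtune import rlhf
from torchtune.rlhf.loss import DPOLoss, PPOLoss
```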
17 changes: 0 additions & 17 deletions docs/source/api_ref_modules.rst
@@ -113,20 +113,3 @@ Functions used for preprocessing images.

transforms.Transform
transforms.VisionCrossAttentionMask

Reinforcement Learning From Human Feedback (RLHF)
--------------------------------------------------
Components and losses for RLHF algorithms like PPO and DPO:

.. autosummary::
:toctree: generated/
:nosignatures:

rlhf.estimate_advantages
rlhf.get_rewards_ppo
rlhf.truncate_sequence_at_first_stop_token
rlhf.loss.PPOLoss
rlhf.loss.DPOLoss
rlhf.loss.RSOLoss
rlhf.loss.IPOLoss
rlhf.loss.SimPOLoss
20 changes: 20 additions & 0 deletions docs/source/api_ref_rlhf.rst
@@ -0,0 +1,20 @@
===============
torchtune.rlhf
===============

.. currentmodule:: torchtune.rlhf

Components and losses for RLHF algorithms like PPO and DPO:

.. autosummary::
:toctree: generated/
:nosignatures:

estimate_advantages
get_rewards_ppo
truncate_sequence_at_first_stop_token
loss.PPOLoss
loss.DPOLoss
loss.RSOLoss
loss.IPOLoss
loss.SimPOLoss
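As a rough usage sketch for the new namespace (the tensor shapes and hyperparameter values here are illustrative only; the call pattern mirrors the updated tests further down in this diff):

```python
import torch

from torchtune import rlhf  # new top-level package introduced by this commit

# Hypothetical per-token value estimates and rewards for 2 sequences of length 4.
values = torch.rand(2, 4)
rewards = torch.rand(2, 4)
gamma, lmbda = 0.99, 0.95  # discount factor and GAE smoothing parameter

# Advantage estimation under the new import path, as exercised in test_rewards.py below.
advantages, _ = rlhf.estimate_advantages(values, rewards, gamma, lmbda)
```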
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -155,5 +155,6 @@ torchtune tutorials.
api_ref_generation
api_ref_models
api_ref_modules
api_ref_rlhf
api_ref_training
api_ref_utilities
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_dpo.yaml
@@ -62,7 +62,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
-_component_: torchtune.modules.rlhf.loss.DPOLoss
+_component_: torchtune.rlhf.loss.DPOLoss
beta: 0.1
label_smoothing: 0
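Only the `_component_` path changes here; for reference, the same loss can be built directly under the new namespace (a sketch mirroring the hyperparameters in this config):

```python
from torchtune.rlhf.loss import DPOLoss

# Same hyperparameters as the config entry above.
loss_fn = DPOLoss(beta=0.1, label_smoothing=0)
```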

2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -61,7 +61,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
-_component_: torchtune.modules.rlhf.loss.DPOLoss
+_component_: torchtune.rlhf.loss.DPOLoss
beta: 0.1
label_smoothing: 0

2 changes: 1 addition & 1 deletion recipes/configs/mistral/7B_full_ppo_low_memory.yaml
@@ -167,7 +167,7 @@ lmbda: 0.95

# PPO hyperparameters
loss:
-_component_: torchtune.modules.rlhf.loss.PPOLoss
+_component_: torchtune.rlhf.loss.PPOLoss
epsilon: 0.2
value_coeff: 0.1
value_clip_range: 0.2
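Likewise, only the component path moves for the PPO recipe; constructing the loss directly under the new namespace would look roughly like this (hyperparameters copied from the config above):

```python
from torchtune.rlhf.loss import PPOLoss

# Same hyperparameters as the config entry above.
loss_fn = PPOLoss(epsilon=0.2, value_coeff=0.1, value_clip_range=0.2)
```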
15 changes: 7 additions & 8 deletions recipes/lora_dpo_distributed.py
@@ -18,10 +18,9 @@
from torch.distributed import destroy_process_group, init_process_group
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import config, modules, training, utils
+from torchtune import config, modules, rlhf, training, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
from torchtune.datasets import ConcatDataset
-from torchtune.modules import rlhf
from torchtune.modules.peft import (
disable_adapter,
DoRALinear,
@@ -32,8 +31,8 @@
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
-from torchtune.modules.rlhf.loss import SimPOLoss
from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -95,10 +94,10 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
- Logging. Terminal, Disk, WandB and TensorBoard are all supported.
The following losses are supported in this recipe:
- :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.modules.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.modules.rlhf.loss.IPO`: Identity Preference Optimization (IPO).
- :class:`~torchtune.modules.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.rlhf.loss.IPO`: Identity Preference Optimization (IPO).
- :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
@@ -582,7 +581,7 @@ def concatenated_forward(
all_log_probs = rlhf.get_batch_log_probs(
all_logits,
concatenated_labels,
-# see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss`
+# see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
)
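For context, `rlhf.get_batch_log_probs` is the same helper as before, now re-exported from the top-level package; a standalone sketch of the call above (tensor shapes are hypothetical):

```python
import torch

from torchtune import rlhf

# Hypothetical batch of 2 concatenated (chosen + rejected) sequences, length 8, vocab size 32.
all_logits = torch.randn(2, 8, 32)
concatenated_labels = torch.randint(0, 32, (2, 8))

# Per-sequence log-probabilities of the labels under the policy; averaging is only
# enabled for SimPO-style losses in the recipe above.
all_log_probs = rlhf.get_batch_log_probs(
    all_logits,
    concatenated_labels,
    return_average_logprobs=False,
)
```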

17 changes: 8 additions & 9 deletions recipes/lora_dpo_single_device.py
@@ -18,10 +18,9 @@
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import config, modules, training, utils
+from torchtune import config, modules, rlhf, training, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
from torchtune.datasets import ConcatDataset
-from torchtune.modules import rlhf
from torchtune.modules.peft import (
disable_adapter,
get_adapter_params,
@@ -30,9 +29,9 @@
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)

-from torchtune.modules.rlhf.loss import SimPOLoss
from torchtune.recipe_interfaces import FTRecipeInterface

+from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -56,10 +55,10 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
The following losses are supported in this recipe:
- :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.modules.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.modules.rlhf.loss.IPOLoss`: Identity Preference Optimization (IPO).
- :class:`~torchtune.modules.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- :class:`~torchtune.rlhf.loss.IPOLoss`: Identity Preference Optimization (IPO).
- :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
Assumptions:
- Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
@@ -471,7 +470,7 @@ def concatenated_forward(
all_log_probs = rlhf.get_batch_log_probs(
all_logits,
concatenated_labels,
-# see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss`
+# see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
)

11 changes: 5 additions & 6 deletions recipes/ppo_full_finetune_single_device.py
@@ -17,12 +17,11 @@
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import config, modules, training, utils
+from torchtune import config, modules, rlhf, training, utils
from torchtune.data import padded_collate
from torchtune.datasets import ConcatDataset
-from torchtune.modules import rlhf
-from torchtune.modules.rlhf import PPOStats, Trajectory
from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.rlhf import PPOStats, Trajectory
from tqdm import tqdm


@@ -680,7 +679,7 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory:
input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length]
Returns:
-Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory` comprising
+Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising
the current trajectory.
"""
batch_size, context_length = input_ids.shape
@@ -799,7 +798,7 @@ def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory:
input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length]
Returns:
-Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory`, comprising
+Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory`, comprising
the current trajectory.
"""
trajectories: List[Trajectory] = []
@@ -947,7 +946,7 @@ def _ppo_step(
context_length (int): input ids sequence length
Returns:
-PPOStats: An instance of :class:`~torchtune.modules.rlhf.PPOStats`, a NamedTuple containing:
+PPOStats: An instance of :class:`~torchtune.rlhf.PPOStats`, a NamedTuple containing:
- loss (torch.Tensor): The total PPO loss.
- policy_loss (torch.Tensor): The policy function loss.
- value_loss (torch.Tensor): The value function loss.
2 changes: 1 addition & 1 deletion tests/torchtune/modules/rlhf/loss/test_dpo_loss.py
@@ -6,7 +6,7 @@

import pytest
import torch
-from torchtune.modules.rlhf.loss import DPOLoss, IPOLoss, RSOLoss, SimPOLoss
+from torchtune.rlhf.loss import DPOLoss, IPOLoss, RSOLoss, SimPOLoss


@pytest.fixture(autouse=True)
2 changes: 1 addition & 1 deletion tests/torchtune/modules/rlhf/loss/test_ppo_loss.py
@@ -6,7 +6,7 @@

import pytest
import torch
-from torchtune.modules.rlhf.loss import PPOLoss
+from torchtune.rlhf.loss import PPOLoss


@pytest.fixture(autouse=True)
4 changes: 2 additions & 2 deletions tests/torchtune/modules/rlhf/test_generation.py
@@ -8,9 +8,9 @@

import torch
from tests.test_utils import fixed_init_model
+from torchtune import rlhf
from torchtune.generation._generation import sample
from torchtune.models.llama2 import llama2
-from torchtune.modules import rlhf


class TestGenerateNextTokenWithLogits:
@@ -61,7 +61,7 @@ def test_generate_next_token_with_logits(self, generation_model):

class TestGenerate:
"""
-Test class for text generation functionality in :func:`~torchtune.modules.rlhf.generate`.
+Test class for text generation functionality in :func:`~torchtune.rlhf.generate`.
See `torchtune.tests.utils.test_generation` for context.
"""

6 changes: 3 additions & 3 deletions tests/torchtune/modules/rlhf/test_rewards.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
-from torchtune.modules import rlhf
+from torchtune import rlhf


class TestGetRewards:
@@ -182,7 +182,7 @@ def test_estimate_advantages_with_whitening(self):
]
)

-# see `torchtune.modules.rlhf.estimate_advantages`
+# see `torchtune.rlhf.estimate_advantages`
expected_advantages = returns - values
expected_whitened_advantages = rlhf.whiten(expected_advantages, shift_mean=True)
advantages, _ = rlhf.estimate_advantages(values, rewards, gamma, lmbda)
@@ -209,7 +209,7 @@ def test_estimate_advantages_with_masks(self):
]
)

-# see `torchtune.modules.rlhf.estimate_advantages`
+# see `torchtune.rlhf.estimate_advantages`
expected_advantages = returns - values
expected_advantages = rlhf.whiten(expected_advantages, mask=masks)
expected_advantages[..., -1] = 0.0
2 changes: 1 addition & 1 deletion tests/torchtune/modules/rlhf/test_sequence_processing.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
-from torchtune.modules import rlhf
+from torchtune import rlhf


class TestTruncateSequenceAtFirstStopToken:
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -255,7 +255,7 @@ class SimPOLoss(nn.Module):
SimPO is pretty much identitcal to DPO but uses average logprobs to eliminate the need for a reference model to regularize
the policy during training. It also uses a target reward margin to guide the policy towards better responses.
-This is kind of the same intuition as in :class:`~torchtune.modules.rlhf.loss.IPOLoss`, but instead of optimizing against
+This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against
a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and
rejected responses.
@@ -8,7 +8,7 @@

import torch
import torch.nn as nn
-from torchtune.modules import rlhf
+from torchtune import rlhf


class PPOLoss(nn.Module):
@@ -19,7 +19,7 @@ def get_reward_penalty_mask(
Calculates a mask to penalise scores corresponding to sequences generated during PPO, where True indicates the score
at the corresponding position should be penalised.
This function assumes sequences have already been truncated at an EOS, if present, and padded to length,
-e.g. by :func:`torchtune.modules.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`.
+e.g. by :func:`torchtune.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`.
Scores are penalised such that:
- If ``min_response_length`` is set, scores for sequences with ``length < min_response_length`` are penalised.
@@ -8,8 +8,8 @@

import torch
import torch.nn.functional as F
+from torchtune import rlhf
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.modules import rlhf


def truncate_sequence_at_first_stop_token(
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion torchtune/training/checkpointing/_checkpointer.py
@@ -18,7 +18,7 @@
from torchtune.models import convert_weights
from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf
from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf
-from torchtune.modules.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf
+from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf
from torchtune.training.checkpointing._utils import (
get_path,
ModelType,
