From 000bb70fc5adf41f1612d9ceb59978b4b8fbe33e Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 16 Sep 2024 03:21:39 +0100 Subject: [PATCH] Move RLHF out of modules (#1591) --- docs/source/api_ref_modules.rst | 17 ---------------- docs/source/api_ref_rlhf.rst | 20 +++++++++++++++++++ docs/source/index.rst | 1 + recipes/configs/llama2/7B_lora_dpo.yaml | 2 +- .../llama2/7B_lora_dpo_single_device.yaml | 2 +- .../mistral/7B_full_ppo_low_memory.yaml | 2 +- recipes/lora_dpo_distributed.py | 15 +++++++------- recipes/lora_dpo_single_device.py | 17 ++++++++-------- recipes/ppo_full_finetune_single_device.py | 11 +++++----- .../modules/rlhf/loss/test_dpo_loss.py | 2 +- .../modules/rlhf/loss/test_ppo_loss.py | 2 +- .../torchtune/modules/rlhf/test_generation.py | 4 ++-- tests/torchtune/modules/rlhf/test_rewards.py | 6 +++--- .../modules/rlhf/test_sequence_processing.py | 2 +- torchtune/{modules => }/rlhf/__init__.py | 0 torchtune/{modules => }/rlhf/_generation.py | 0 torchtune/{modules => }/rlhf/_types.py | 0 torchtune/{modules => }/rlhf/loss/__init__.py | 0 torchtune/{modules => }/rlhf/loss/dpo.py | 2 +- torchtune/{modules => }/rlhf/loss/ppo.py | 2 +- torchtune/{modules => }/rlhf/rewards.py | 2 +- .../{modules => }/rlhf/sequence_processing.py | 2 +- .../{modules => }/rlhf/utils/__init__.py | 0 .../rlhf/utils/_convert_weights.py | 0 .../training/checkpointing/_checkpointer.py | 2 +- 25 files changed, 57 insertions(+), 56 deletions(-) create mode 100644 docs/source/api_ref_rlhf.rst rename torchtune/{modules => }/rlhf/__init__.py (100%) rename torchtune/{modules => }/rlhf/_generation.py (100%) rename torchtune/{modules => }/rlhf/_types.py (100%) rename torchtune/{modules => }/rlhf/loss/__init__.py (100%) rename torchtune/{modules => }/rlhf/loss/dpo.py (99%) rename torchtune/{modules => }/rlhf/loss/ppo.py (99%) rename torchtune/{modules => }/rlhf/rewards.py (98%) rename torchtune/{modules => }/rlhf/sequence_processing.py (99%) rename torchtune/{modules => 
}/rlhf/utils/__init__.py (100%) rename torchtune/{modules => }/rlhf/utils/_convert_weights.py (100%) diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index d1fb60144a..76dae8600d 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -113,20 +113,3 @@ Functions used for preprocessing images. transforms.Transform transforms.VisionCrossAttentionMask - -Reinforcement Learning From Human Feedback (RLHF) --------------------------------------------------- -Components and losses for RLHF algorithms like PPO and DPO: - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - rlhf.estimate_advantages - rlhf.get_rewards_ppo - rlhf.truncate_sequence_at_first_stop_token - rlhf.loss.PPOLoss - rlhf.loss.DPOLoss - rlhf.loss.RSOLoss - rlhf.loss.IPOLoss - rlhf.loss.SimPOLoss diff --git a/docs/source/api_ref_rlhf.rst b/docs/source/api_ref_rlhf.rst new file mode 100644 index 0000000000..e68ca8aed1 --- /dev/null +++ b/docs/source/api_ref_rlhf.rst @@ -0,0 +1,20 @@ +=============== +torchtune.rlhf +=============== + +.. currentmodule:: torchtune.rlhf + +Components and losses for RLHF algorithms like PPO and DPO: + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + estimate_advantages + get_rewards_ppo + truncate_sequence_at_first_stop_token + loss.PPOLoss + loss.DPOLoss + loss.RSOLoss + loss.IPOLoss + loss.SimPOLoss diff --git a/docs/source/index.rst b/docs/source/index.rst index 98676e4ce9..f3a95281fe 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -155,5 +155,6 @@ torchtune tutorials. 
api_ref_generation api_ref_models api_ref_modules + api_ref_rlhf api_ref_training api_ref_utilities diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index 78b16c6e04..f6acfcb76e 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -62,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.rlhf.loss.DPOLoss + _component_: torchtune.rlhf.loss.DPOLoss beta: 0.1 label_smoothing: 0 diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index 5fb50ff6e8..d51847d1e7 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -61,7 +61,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.rlhf.loss.DPOLoss + _component_: torchtune.rlhf.loss.DPOLoss beta: 0.1 label_smoothing: 0 diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml index adf11ce486..1cf7dd974a 100644 --- a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -167,7 +167,7 @@ lmbda: 0.95 # PPO hyperparameters loss: - _component_: torchtune.modules.rlhf.loss.PPOLoss + _component_: torchtune.rlhf.loss.PPOLoss epsilon: 0.2 value_coeff: 0.1 value_clip_range: 0.2 diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 1d928e4fc8..d381dc9832 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -18,10 +18,9 @@ from torch.distributed import destroy_process_group, init_process_group from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils +from torchtune import config, modules, rlhf, training, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, 
padded_collate_dpo from torchtune.datasets import ConcatDataset -from torchtune.modules import rlhf from torchtune.modules.peft import ( disable_adapter, DoRALinear, @@ -32,8 +31,8 @@ set_trainable_params, validate_missing_and_unexpected_for_lora, ) -from torchtune.modules.rlhf.loss import SimPOLoss from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.rlhf.loss import SimPOLoss from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -95,10 +94,10 @@ class LoRADPORecipeDistributed(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. The following losses are supported in this recipe: - - :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). - - :class:`~torchtune.modules.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). - - :class:`~torchtune.modules.rlhf.loss.IPO`: Identity Preference Optimization (IPO). - - :class:`~torchtune.modules.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). + - :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). + - :class:`~torchtune.rlhf.loss.RSOLoss`: Rejection Sampling Optimization (RSO). + - :class:`~torchtune.rlhf.loss.IPOLoss`: Identity Preference Optimization (IPO). + - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. 
@@ -582,7 +581,7 @@ def concatenated_forward( all_log_probs = rlhf.get_batch_log_probs( all_logits, concatenated_labels, - # see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss` + # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss` return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), ) diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index 8411d8088d..4cb53b2afb 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -18,10 +18,9 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils +from torchtune import config, modules, rlhf, training, utils from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo from torchtune.datasets import ConcatDataset -from torchtune.modules import rlhf from torchtune.modules.peft import ( disable_adapter, get_adapter_params, @@ -30,9 +29,9 @@ validate_missing_and_unexpected_for_lora, validate_state_dict_for_lora, ) - -from torchtune.modules.rlhf.loss import SimPOLoss from torchtune.recipe_interfaces import FTRecipeInterface + +from torchtune.rlhf.loss import SimPOLoss from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -56,10 +55,10 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface): The following losses are supported in this recipe: - - :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). - - :class:`~torchtune.modules.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO). - - :class:`~torchtune.modules.rlhf.loss.IPOLoss`: Identity Preference Optimization (IPO). - - :class:`~torchtune.modules.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). + - :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO). + - :class:`~torchtune.rlhf.loss.RSOLoss`: Rejection Sampling Optimization (RSO). 
+ - :class:`~torchtune.rlhf.loss.IPOLoss`: Identity Preference Optimization (IPO). + - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO). Assumptions: - Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done @@ -471,7 +470,7 @@ def concatenated_forward( all_log_probs = rlhf.get_batch_log_probs( all_logits, concatenated_labels, - # see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss` + # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss` return_average_logprobs=isinstance(self._loss_fn, SimPOLoss), ) diff --git a/recipes/ppo_full_finetune_single_device.py b/recipes/ppo_full_finetune_single_device.py index bdd63e8cdc..b9840fc067 100644 --- a/recipes/ppo_full_finetune_single_device.py +++ b/recipes/ppo_full_finetune_single_device.py @@ -17,12 +17,11 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils +from torchtune import config, modules, rlhf, training, utils from torchtune.data import padded_collate from torchtune.datasets import ConcatDataset -from torchtune.modules import rlhf -from torchtune.modules.rlhf import PPOStats, Trajectory from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.rlhf import PPOStats, Trajectory from tqdm import tqdm @@ -680,7 +679,7 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] Returns: - Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory` comprising + Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory` comprising the current trajectory. 
""" batch_size, context_length = input_ids.shape @@ -799,7 +798,7 @@ def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] Returns: - Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory`, comprising + Trajectory: An instance of :class:`~torchtune.rlhf.Trajectory`, comprising the current trajectory. """ trajectories: List[Trajectory] = [] @@ -947,7 +946,7 @@ def _ppo_step( context_length (int): input ids sequence length Returns: - PPOStats: An instance of :class:`~torchtune.modules.rlhf.PPOStats`, a NamedTuple containing: + PPOStats: An instance of :class:`~torchtune.rlhf.PPOStats`, a NamedTuple containing: - loss (torch.Tensor): The total PPO loss. - policy_loss (torch.Tensor): The policy function loss. - value_loss (torch.Tensor): The value function loss. diff --git a/tests/torchtune/modules/rlhf/loss/test_dpo_loss.py b/tests/torchtune/modules/rlhf/loss/test_dpo_loss.py index ef82aad873..e56621f3ed 100644 --- a/tests/torchtune/modules/rlhf/loss/test_dpo_loss.py +++ b/tests/torchtune/modules/rlhf/loss/test_dpo_loss.py @@ -6,7 +6,7 @@ import pytest import torch -from torchtune.modules.rlhf.loss import DPOLoss, IPOLoss, RSOLoss, SimPOLoss +from torchtune.rlhf.loss import DPOLoss, IPOLoss, RSOLoss, SimPOLoss @pytest.fixture(autouse=True) diff --git a/tests/torchtune/modules/rlhf/loss/test_ppo_loss.py b/tests/torchtune/modules/rlhf/loss/test_ppo_loss.py index 97a45dace2..01a2388770 100644 --- a/tests/torchtune/modules/rlhf/loss/test_ppo_loss.py +++ b/tests/torchtune/modules/rlhf/loss/test_ppo_loss.py @@ -6,7 +6,7 @@ import pytest import torch -from torchtune.modules.rlhf.loss import PPOLoss +from torchtune.rlhf.loss import PPOLoss @pytest.fixture(autouse=True) diff --git a/tests/torchtune/modules/rlhf/test_generation.py b/tests/torchtune/modules/rlhf/test_generation.py index bfe8237c0c..511ecfcdc4 100644 --- 
a/tests/torchtune/modules/rlhf/test_generation.py +++ b/tests/torchtune/modules/rlhf/test_generation.py @@ -8,9 +8,9 @@ import torch from tests.test_utils import fixed_init_model +from torchtune import rlhf from torchtune.generation._generation import sample from torchtune.models.llama2 import llama2 -from torchtune.modules import rlhf class TestGenerateNextTokenWithLogits: @@ -61,7 +61,7 @@ def test_generate_next_token_with_logits(self, generation_model): class TestGenerate: """ - Test class for text generation functionality in :func:`~torchtune.modules.rlhf.generate`. + Test class for text generation functionality in :func:`~torchtune.rlhf.generate`. See `torchtune.tests.utils.test_generation` for context. """ diff --git a/tests/torchtune/modules/rlhf/test_rewards.py b/tests/torchtune/modules/rlhf/test_rewards.py index 0e8ec998fa..4284d1d63d 100644 --- a/tests/torchtune/modules/rlhf/test_rewards.py +++ b/tests/torchtune/modules/rlhf/test_rewards.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torchtune.modules import rlhf +from torchtune import rlhf class TestGetRewards: @@ -182,7 +182,7 @@ def test_estimate_advantages_with_whitening(self): ] ) - # see `torchtune.modules.rlhf.estimate_advantages` + # see `torchtune.rlhf.estimate_advantages` expected_advantages = returns - values expected_whitened_advantages = rlhf.whiten(expected_advantages, shift_mean=True) advantages, _ = rlhf.estimate_advantages(values, rewards, gamma, lmbda) @@ -209,7 +209,7 @@ def test_estimate_advantages_with_masks(self): ] ) - # see `torchtune.modules.rlhf.estimate_advantages` + # see `torchtune.rlhf.estimate_advantages` expected_advantages = returns - values expected_advantages = rlhf.whiten(expected_advantages, mask=masks) expected_advantages[..., -1] = 0.0 diff --git a/tests/torchtune/modules/rlhf/test_sequence_processing.py b/tests/torchtune/modules/rlhf/test_sequence_processing.py index 43accdf80c..ae53e6494f 100644 --- a/tests/torchtune/modules/rlhf/test_sequence_processing.py +++ b/tests/torchtune/modules/rlhf/test_sequence_processing.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torchtune.modules import rlhf +from torchtune import rlhf class TestTruncateSequenceAtFirstStopToken: diff --git a/torchtune/modules/rlhf/__init__.py b/torchtune/rlhf/__init__.py similarity index 100% rename from torchtune/modules/rlhf/__init__.py rename to torchtune/rlhf/__init__.py diff --git a/torchtune/modules/rlhf/_generation.py b/torchtune/rlhf/_generation.py similarity index 100% rename from torchtune/modules/rlhf/_generation.py rename to torchtune/rlhf/_generation.py diff --git a/torchtune/modules/rlhf/_types.py b/torchtune/rlhf/_types.py similarity index 100% rename from torchtune/modules/rlhf/_types.py rename to torchtune/rlhf/_types.py diff --git a/torchtune/modules/rlhf/loss/__init__.py b/torchtune/rlhf/loss/__init__.py similarity index 100% rename from torchtune/modules/rlhf/loss/__init__.py rename to torchtune/rlhf/loss/__init__.py diff --git a/torchtune/modules/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py similarity index 99% rename from torchtune/modules/rlhf/loss/dpo.py rename to torchtune/rlhf/loss/dpo.py index d1ca35cd11..d49dbad676 100644 --- a/torchtune/modules/rlhf/loss/dpo.py +++ b/torchtune/rlhf/loss/dpo.py @@ -255,7 +255,7 @@ class SimPOLoss(nn.Module): SimPO is pretty much identitcal to DPO but uses average logprobs to eliminate the need for a reference model to regularize the policy during training. It also uses a target reward margin to guide the policy towards better responses. - This is kind of the same intuition as in :class:`~torchtune.modules.rlhf.loss.IPOLoss`, but instead of optimizing against + This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and rejected responses. 
diff --git a/torchtune/modules/rlhf/loss/ppo.py b/torchtune/rlhf/loss/ppo.py similarity index 99% rename from torchtune/modules/rlhf/loss/ppo.py rename to torchtune/rlhf/loss/ppo.py index 0cef4a5301..d4770802f7 100644 --- a/torchtune/modules/rlhf/loss/ppo.py +++ b/torchtune/rlhf/loss/ppo.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from torchtune.modules import rlhf +from torchtune import rlhf class PPOLoss(nn.Module): diff --git a/torchtune/modules/rlhf/rewards.py b/torchtune/rlhf/rewards.py similarity index 98% rename from torchtune/modules/rlhf/rewards.py rename to torchtune/rlhf/rewards.py index dd9f970376..f0e42ca58c 100644 --- a/torchtune/modules/rlhf/rewards.py +++ b/torchtune/rlhf/rewards.py @@ -19,7 +19,7 @@ def get_reward_penalty_mask( Calculates a mask to penalise scores corresponding to sequences generated during PPO, where True indicates the score at the corresponding position should be penalised. This function assumes sequences have already been truncated at an EOS, if present, and padded to length, - e.g. by :func:`torchtune.modules.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`. + e.g. by :func:`torchtune.rlhf.sequence_processing.truncate_sequence_at_first_stop_token`. Scores are penalised such that: - If ``min_response_length`` is set, scores for sequences with ``length < min_response_length`` are penalised. 
diff --git a/torchtune/modules/rlhf/sequence_processing.py b/torchtune/rlhf/sequence_processing.py similarity index 99% rename from torchtune/modules/rlhf/sequence_processing.py rename to torchtune/rlhf/sequence_processing.py index adbce7cd6c..9844dd001c 100644 --- a/torchtune/modules/rlhf/sequence_processing.py +++ b/torchtune/rlhf/sequence_processing.py @@ -8,8 +8,8 @@ import torch import torch.nn.functional as F +from torchtune import rlhf from torchtune.data import CROSS_ENTROPY_IGNORE_IDX -from torchtune.modules import rlhf def truncate_sequence_at_first_stop_token( diff --git a/torchtune/modules/rlhf/utils/__init__.py b/torchtune/rlhf/utils/__init__.py similarity index 100% rename from torchtune/modules/rlhf/utils/__init__.py rename to torchtune/rlhf/utils/__init__.py diff --git a/torchtune/modules/rlhf/utils/_convert_weights.py b/torchtune/rlhf/utils/_convert_weights.py similarity index 100% rename from torchtune/modules/rlhf/utils/_convert_weights.py rename to torchtune/rlhf/utils/_convert_weights.py diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index fc4a3d2816..ad681db5a5 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -18,7 +18,7 @@ from torchtune.models import convert_weights from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf -from torchtune.modules.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf +from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf from torchtune.training.checkpointing._utils import ( get_path, ModelType,