Skip to content

Commit

Permalink
add support for rpo_alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jun 3, 2024
1 parent 5cde065 commit 7f7c1a3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs

trl==0.8.6
trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
zstandard==0.22.0
fastcore
13 changes: 11 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length

from axolotl.loraplus import create_loraplus_optimizer
Expand Down Expand Up @@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
Expand Down Expand Up @@ -1608,7 +1615,9 @@ def build_training_arguments(self, total_num_steps):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

training_args_cls = AxolotlTrainingArguments
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ class Config:
neftune_noise_alpha: Optional[float] = None

orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None

kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
Expand Down

0 comments on commit 7f7c1a3

Please sign in to comment.