Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔀 Add MergeModelCallBack #2282

Merged
merged 48 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
63f7ae0
Create mergekit_utils.py
August-murr Oct 25, 2024
55f3a25
adding mergekit as an optional dependency
August-murr Oct 25, 2024
9f829fa
adding MergeModel to callbacks
August-murr Oct 25, 2024
f666369
adding mergekit_utils dependencies to callbacks
August-murr Oct 25, 2024
bb52d4b
setting lower bound for mergekit
August-murr Oct 30, 2024
5d9a5a9
setting mergekit lower bound to 0.0.5.1
August-murr Oct 30, 2024
9ba148c
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 5, 2024
0fe2d67
adding support for MergeModelCallBack __init__.py
August-murr Nov 14, 2024
caee4b2
adding support for mergemodelcallback
August-murr Nov 14, 2024
b278e5a
mergemodelcallback tests
August-murr Nov 14, 2024
311a27a
Update callbacks.py
August-murr Nov 14, 2024
ec8b5ec
Update __init__.py
August-murr Nov 14, 2024
147b188
Update __init__.py
August-murr Nov 14, 2024
7f5dbc1
Update test_callbacks.py
August-murr Nov 14, 2024
c17f8e3
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
0d9c8a0
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
ca07b42
Update trl/trainer/callbacks.py
August-murr Nov 15, 2024
6229885
using different dataset for tests
August-murr Nov 15, 2024
9ea5a6b
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
1a2b425
Update trl/mergekit_utils.py
August-murr Nov 15, 2024
f0f84eb
Merge branch 'main' into MergeModelCallBack
qgallouedec Nov 15, 2024
c924cb9
Apply suggestions from code review
August-murr Nov 15, 2024
8b5a4a9
replacing get_last_checkpoint
August-murr Nov 15, 2024
34e94ff
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
August-murr Nov 15, 2024
0a5db60
renaming Merge to merge_models
August-murr Nov 15, 2024
906eafa
setting mergers default value to linear
August-murr Nov 15, 2024
8d03608
removing unnecessary docs and comments
August-murr Nov 15, 2024
eb66e99
adding docstring to Mergeconfig
August-murr Nov 15, 2024
eb7b228
adding mergekits link to docstring
August-murr Nov 15, 2024
1057c59
precommit
August-murr Nov 15, 2024
1c85ee5
removing duplicated import
August-murr Nov 16, 2024
18d0388
typos in mergekit_utils docstring
August-murr Nov 16, 2024
ca8f361
fixing tests
August-murr Nov 17, 2024
0a25ee8
making mergemodelcallback tests optional
August-murr Nov 18, 2024
cd76890
Make import optional
qgallouedec Nov 18, 2024
c8afdbc
minor
qgallouedec Nov 18, 2024
26aa418
Merge branch 'main' into MergeModelCallBack
kashif Nov 20, 2024
fb12119
Merge branch 'MergeModelCallBack' of https://github.com/August-murr/t…
qgallouedec Nov 21, 2024
88de59a
use tmp dir in test
qgallouedec Nov 21, 2024
fed6045
sort
qgallouedec Nov 21, 2024
01e38c8
Add import error checks for mergekit extra
qgallouedec Nov 21, 2024
9e7a068
use a common _merge_and_maybe_push method and compat with windows path
qgallouedec Nov 21, 2024
fa5bafe
debug windows
qgallouedec Nov 21, 2024
47ecef5
Update dependencies for mergekit and add test dependencies
qgallouedec Nov 21, 2024
7fe26d5
Add assertion to check if merged folder exists in the last checkpoint
qgallouedec Nov 21, 2024
a57d88a
Fix temporary directory cleanup in test_callbacks.py
qgallouedec Nov 21, 2024
69ea0ed
Add sys import and skip test for Python versions below 3.10 due to cl…
qgallouedec Nov 21, 2024
d89eadc
revert change for debug
qgallouedec Nov 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/callbacks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@

## LogCompletionsCallback

[[autodoc]] LogCompletionsCallback
[[autodoc]] LogCompletionsCallback

## MergeModelCallback

[[autodoc]] MergeModelCallback
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@
# DeepSpeed only partially supports Windows https://github.com/microsoft/DeepSpeed/tree/master#windows
"deepspeed": ["deepspeed>=0.14.4; sys_platform != 'win32'"],
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
Expand Down
66 changes: 65 additions & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_peft, require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available

from trl import BasePairwiseJudge, LogCompletionsCallback, WinRateCallback
from tests.testing_utils import require_mergekit
from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
from trl.mergekit_utils import MergeConfig


if is_peft_available():
Expand Down Expand Up @@ -266,3 +269,64 @@ def test_basic(self):

# Check that the prompt is in the log
self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])


@require_mergekit
class MergeModelCallbackTester(unittest.TestCase):
    """
    Tests that training with `MergeModelCallback` produces a `merged` folder inside the saved checkpoints.
    """

    def setUp(self):
        # Model, tokenizer and preference dataset shared by every test in this class.
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def _train_with_merge_callback(self, output_dir, **callback_kwargs):
        """
        Run one epoch of DPO training that saves a checkpoint at every step, with a `MergeModelCallback` attached.

        Args:
            output_dir: Directory in which the trainer writes its checkpoints.
            **callback_kwargs: Extra keyword arguments forwarded to `MergeModelCallback`.
        """
        training_args = DPOConfig(
            output_dir=output_dir,
            num_train_epochs=1,
            report_to="none",
            save_strategy="steps",
            save_steps=1,
        )
        merge_callback = MergeModelCallback(MergeConfig(), **callback_kwargs)
        trainer = DPOTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
            tokenizer=self.tokenizer,
            callbacks=[merge_callback],
        )
        trainer.train()

    def test_callback(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._train_with_merge_callback(tmp_dir)

            last_checkpoint = get_last_checkpoint(tmp_dir)
            # `get_last_checkpoint` returns None when no checkpoint folder exists; fail with a clear
            # message instead of a TypeError from `os.path.join(None, ...)`.
            self.assertIsNotNone(last_checkpoint, "No checkpoint was saved during training.")
            merged_path = os.path.join(last_checkpoint, "merged")
            self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.")

    def test_every_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._train_with_merge_callback(tmp_dir, merge_at_every_checkpoint=True)

            checkpoints = sorted(
                os.path.join(tmp_dir, cp) for cp in os.listdir(tmp_dir) if cp.startswith("checkpoint-")
            )
            # Without this guard the loop below would pass vacuously if training saved nothing.
            self.assertTrue(checkpoints, "No checkpoint was saved during training.")

            for checkpoint in checkpoints:
                merged_path = os.path.join(checkpoint, "merged")
                self.assertTrue(
                    os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
                )
22 changes: 15 additions & 7 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from transformers import is_bitsandbytes_available, is_sklearn_available, is_wandb_available

from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available
from trl.import_utils import is_mergekit_available


# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use
Expand All @@ -35,6 +36,20 @@ def require_diffusers(test_case):
return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case)


def require_llm_blender(test_case):
    """
    Decorator for tests that depend on llm-blender: the decorated test is skipped unless llm-blender is available.
    """
    skip_unless_available = unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")
    return skip_unless_available(test_case)


def require_mergekit(test_case):
    """
    Decorator for tests that depend on mergekit: the decorated test is skipped unless mergekit is available.
    """
    skip_unless_available = unittest.skipUnless(is_mergekit_available(), "test requires mergekit")
    return skip_unless_available(test_case)


def require_no_wandb(test_case):
"""
Decorator marking a test that requires no wandb. Skips the test if wandb is available.
Expand All @@ -49,13 +64,6 @@ def require_sklearn(test_case):
return unittest.skipUnless(is_sklearn_available(), "test requires sklearn")(test_case)


def require_llm_blender(test_case):
    """
    Decorator for tests that depend on llm-blender: the decorated test is skipped unless llm-blender is available.
    """
    skip_unless_available = unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")
    return skip_unless_available(test_case)


class RandomBinaryJudge(BaseBinaryJudge):
"""
Random binary judge, for testing purposes.
Expand Down
4 changes: 3 additions & 1 deletion trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"KTOConfig",
"KTOTrainer",
"LogCompletionsCallback",
"MergeModelCallback",
"ModelConfig",
"NashMDConfig",
"NashMDTrainer",
Expand All @@ -91,7 +92,7 @@
"XPOConfig",
"XPOTrainer",
],
"trainer.callbacks": ["RichProgressCallback", "SyncRefModelCallback"],
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
"utils": ["ScriptArguments"],
}
Expand Down Expand Up @@ -159,6 +160,7 @@
KTOConfig,
KTOTrainer,
LogCompletionsCallback,
MergeModelCallback,
ModelConfig,
NashMDConfig,
NashMDTrainer,
Expand Down
5 changes: 5 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_deepspeed_available = _is_package_available("deepspeed")
_diffusers_available = _is_package_available("diffusers")
_llm_blender_available = _is_package_available("llm_blender")
_mergekit_available = _is_package_available("mergekit")
_rich_available = _is_package_available("rich")
_unsloth_available = _is_package_available("unsloth")

Expand All @@ -40,6 +41,10 @@ def is_llm_blender_available() -> bool:
return _llm_blender_available


def is_mergekit_available() -> bool:
    """
    Return the module-level `_mergekit_available` flag, which is set once at import time from
    `_is_package_available("mergekit")`.
    """
    return _mergekit_available


def is_rich_available() -> bool:
    """
    Return the module-level `_rich_available` flag, which is set once at import time from
    `_is_package_available("rich")`.
    """
    return _rich_available

Expand Down
Loading
Loading