-
Notifications
You must be signed in to change notification settings - Fork 441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Only merge model weights in LoRA recipe when save_adapter_weights_only=False
#1476
Changes from 10 commits
e0fbabb
3ac8490
70a43f2
cd0a13f
b3930e3
7fc10f7
b8b805c
d4c9d22
4c8e4d4
05620fe
4991014
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -486,14 +486,16 @@ def _setup_data( | |
dataset=ds, | ||
sampler=sampler, | ||
batch_size=batch_size, | ||
collate_fn=partial( | ||
padded_collate_sft, | ||
padding_idx=self._tokenizer.pad_id, | ||
ignore_idx=self._loss_fn.ignore_index, | ||
) | ||
if not packed | ||
else partial( | ||
padded_collate_packed, | ||
collate_fn=( | ||
partial( | ||
padded_collate_sft, | ||
padding_idx=self._tokenizer.pad_id, | ||
ignore_idx=self._loss_fn.ignore_index, | ||
) | ||
if not packed | ||
else partial( | ||
padded_collate_packed, | ||
) | ||
), | ||
) | ||
|
||
|
@@ -527,24 +529,34 @@ def save_checkpoint(self, epoch: int) -> None: | |
} | ||
) | ||
|
||
# Move to CPU to avoid a copy on GPU | ||
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} | ||
|
||
# Construct the adapter weights | ||
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice | ||
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys | ||
adapter_key_filter = lambda x: x in self.adapter_params | ||
adapter_state_dict = { | ||
k: v for k, v in state_dict.items() if adapter_key_filter(k) | ||
} | ||
if not self._save_adapter_weights_only: | ||
# Construct the full state dict with LoRA weights merged into base LLM weights | ||
|
||
# Move to CPU to avoid a copy on GPU | ||
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} | ||
|
||
# Construct the adapter weights | ||
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice | ||
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys | ||
adapter_state_dict = { | ||
k: v for k, v in state_dict.items() if adapter_key_filter(k) | ||
} | ||
|
||
merged_state_dict = get_merged_lora_ckpt( | ||
state_dict, | ||
rank=self._lora_rank, | ||
alpha=self._lora_alpha, | ||
) | ||
|
||
# Construct the full state dict with LoRA weights merged into base LLM weights | ||
merged_state_dict = get_merged_lora_ckpt( | ||
state_dict, | ||
rank=self._lora_rank, | ||
alpha=self._lora_alpha, | ||
) | ||
ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) | ||
ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) | ||
else: | ||
# No need to merge state dict if we're only saving adapter weights | ||
adapter_state_dict = { | ||
k: v | ||
for k, v in self._model.state_dict().items() | ||
if adapter_key_filter(k) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't this still materialize the entire state dict, though? |
||
|
||
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) | ||
adapter_config = { | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're a hero. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was just afraid someone would ask me to run the recipe manually, so this felt like less effort. |
||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
import runpy | ||
import sys | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
from omegaconf import OmegaConf | ||
from tests.common import TUNE_PATH | ||
from tests.recipes.utils import ( | ||
dummy_stack_exchange_dataset_config, | ||
MODEL_TEST_CONFIGS, | ||
write_hf_ckpt_config, | ||
) | ||
from tests.test_utils import ( | ||
CKPT_MODEL_PATHS, | ||
gen_log_file_name, | ||
get_loss_values_from_metric_logger, | ||
) | ||
from torchtune import config | ||
|
||
|
||
class TestLoRADPOSingleDeviceRecipe:
    """Integration tests for the ``lora_dpo_single_device`` recipe.

    Every test drives the recipe through the ``tune`` CLI entry point: the
    command line is assembled as a token list, installed as ``sys.argv`` via
    ``monkeypatch`` and executed by running ``TUNE_PATH`` with ``runpy``.
    """

    def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
        """Return the config overrides appended to every recipe invocation.

        Args:
            dtype_str: value used for the ``dtype`` override.
            epochs: value used for the ``epochs`` override.

        Returns:
            The fixed overrides below followed by the entries returned by
            ``dummy_stack_exchange_dataset_config()``.
        """
        shared_overrides = [
            "batch_size=8",
            "device=cpu",
            f"dtype={dtype_str}",
            "enable_activation_checkpointing=False",
            "dataset.train_on_input=False",
            "seed=9",
            f"epochs={epochs}",
            "max_steps_per_epoch=2",
            "optimizer.lr=2e-5",
            "log_every_n_steps=1",
            "gradient_accumulation_steps=1",
            "clip_grad_norm=100",
            "tokenizer.max_seq_len=512",
        ]
        return shared_overrides + dummy_stack_exchange_dataset_config()

    @pytest.mark.parametrize("save_adapter_weights_only", [False, True])
    @pytest.mark.integration_test
    def test_training_state_on_resume(
        self, tmpdir, monkeypatch, save_adapter_weights_only
    ):
        """Check that resuming from a checkpoint reproduces the original losses.

        The check is model agnostic, so it only runs against the small test
        model. It proceeds in three steps:
            - run the recipe with the default two epochs
            - run it again with ``resume_from_checkpoint=True``, pointing at
              ``adapter_0.pt`` and ``recipe_state.pt`` in ``tmpdir``
            - assert the first two losses of the resumed run match the last
              two losses of the first run

        Unlike `tests.recipes.test_lora_finetune_single_device`, no
        pre-computed loss values are used as a benchmark; the only requirement
        is that both runs log matching losses.
        """
        ckpt_path = Path(CKPT_MODEL_PATHS["llama2_hf"])
        ckpt_dir = ckpt_path.parent
        log_file = gen_log_file_name(tmpdir)

        # The checkpointer needs a config file for model conversion; a second
        # copy goes into tmpdir, which is the checkpoint dir of the resumed run.
        write_hf_ckpt_config(ckpt_dir)
        write_hf_ckpt_config(tmpdir)

        # First run: two epochs (the default of _get_test_config_overrides).
        train_tokens = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelHFCheckpointer \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
            save_adapter_weights_only={save_adapter_weights_only} \
            metric_logger.filename={log_file} \
        """.split()

        lora_config = MODEL_TEST_CONFIGS["llama2_lora"]

        train_argv = train_tokens + self._get_test_config_overrides() + lora_config
        monkeypatch.setattr(sys, "argv", train_argv)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        expected_loss_values = get_loss_values_from_metric_logger(log_file)

        # Second run: resume, logging into a separate directory.
        resumed_log_dir = (tmpdir / "resumed/").mkdir()
        resumed_log_file = gen_log_file_name(resumed_log_dir)
        adapter_ckpt = os.path.join(tmpdir, "adapter_0.pt")
        recipe_ckpt = os.path.join(tmpdir, "recipe_state.pt")
        resume_tokens = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelHFCheckpointer \
            checkpointer.checkpoint_dir={tmpdir} \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.adapter_checkpoint={adapter_ckpt}
            checkpointer.recipe_checkpoint={recipe_ckpt}
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            resume_from_checkpoint=True \
            metric_logger.filename={resumed_log_file} \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
        """.split()
        resume_argv = (
            resume_tokens + self._get_test_config_overrides(epochs=3) + lora_config
        )
        monkeypatch.setattr(sys, "argv", resume_argv)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)

        # Only the second epoch overlaps between the two runs: the first two
        # resumed losses are compared with the losses after the first two
        # entries of the original run.
        torch.testing.assert_close(
            resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
        )

    @pytest.mark.integration_test
    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
        """Check that the merged checkpoint matches base weights plus adapter.

        After running the recipe, the saved adapter and the original base
        checkpoint are loaded into a LoRA model, and the merged checkpoint is
        loaded into a plain Llama2 model. Both models are called on the same
        random token ids and their outputs must agree.
        """
        ckpt_path = Path(CKPT_MODEL_PATHS["llama2_tune"])
        ckpt_dir = ckpt_path.parent

        run_tokens = f"""
        tune run lora_dpo_single_device \
            --config llama2/7B_lora_dpo_single_device \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
            checkpointer.checkpoint_files=[{ckpt_path}]\
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tokenizer.prompt_template=null \
        """.split()

        lora_config = MODEL_TEST_CONFIGS["llama2_lora"]

        run_argv = run_tokens + self._get_test_config_overrides() + lora_config
        monkeypatch.setattr(sys, "argv", run_argv)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        # Dummy token ids fed to both models below.
        inputs = torch.randint(low=0, high=32_000, size=(2, 100))

        # LoRA model: receives the base weights and the adapter separately.
        lora_model = config.instantiate(OmegaConf.from_dotlist(lora_config).model)

        # Plain Llama2 model: receives the merged checkpoint.
        base_config = MODEL_TEST_CONFIGS["llama2"]
        llama2_model = config.instantiate(OmegaConf.from_dotlist(base_config).model)

        # Baseline: trained adapter + original base checkpoint. strict=False
        # because each state dict covers only part of the LoRA model's keys.
        with open(f"{tmpdir}/adapter_1.pt", "rb") as adapter_file:
            lora_sd = torch.load(adapter_file, weights_only=True)
        with open(ckpt_path, "rb") as base_file:
            base_model_sd = torch.load(base_file, weights_only=True)
        lora_model.load_state_dict(lora_sd, strict=False)
        lora_model.load_state_dict(base_model_sd, strict=False)
        baseline_out = lora_model(inputs)

        # Candidate: merged final checkpoint loaded directly into Llama2.
        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as merged_file:
            merged_sd = torch.load(merged_file, weights_only=True)
        llama2_model.load_state_dict(merged_sd)
        merged_ckpt_out = llama2_model(inputs)

        torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @pbontrager had looked at something with this previously, but we still create the full state dict on CPU in this case then, right? To actually reduce the memory footprint to just the adapter weights, we'd need to change the call to `get_full_model_state_dict` above (btw it's OK to say "let's do this in a follow-up").