Only merge model weights in LoRA recipe when save_adapter_weights_only=False #1476

Merged: 11 commits, Sep 15, 2024
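The change is the same across the four recipes touched below: merging the LoRA weights into the base model is now skipped when save_adapter_weights_only=True, while the adapter weights are always saved. A minimal sketch of the pattern, assuming the torchtune.modules.peft import path and simplified checkpoint keys (the real recipes use training.MODEL_KEY / training.ADAPTER_KEY):

from typing import Any, Dict, Set

from torchtune.modules.peft import get_merged_lora_ckpt  # assumed import path

def build_checkpoint_dict(
    state_dict: Dict[str, Any],
    adapter_params: Set[str],
    save_adapter_weights_only: bool,
    lora_rank: int,
    lora_alpha: float,
) -> Dict[str, Any]:
    # Adapter weights are always saved. Capture them before any merge,
    # because merging removes the lora_* keys from the state dict.
    adapter_sd = {k: v for k, v in state_dict.items() if k in adapter_params}
    ckpt = {"adapter": adapter_sd}
    if not save_adapter_weights_only:
        # Only build the (large) merged checkpoint when explicitly requested.
        merged = get_merged_lora_ckpt(state_dict, rank=lora_rank, alpha=lora_alpha)
        ckpt["model"] = merged
    return ckpt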
13 changes: 7 additions & 6 deletions recipes/lora_dpo_distributed.py
@@ -529,12 +529,13 @@ def save_checkpoint(
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
if not self._save_adapter_weights_only:
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
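For reference, the merge that the flag now gates folds each low-rank pair into its base weight. A minimal sketch of the arithmetic (hypothetical helper, assuming the standard LoRA scaling of alpha / rank):

import torch

def merge_one_lora_weight(
    base: torch.Tensor,    # (out_dim, in_dim)
    lora_a: torch.Tensor,  # (rank, in_dim)
    lora_b: torch.Tensor,  # (out_dim, rank)
    rank: int,
    alpha: float,
) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / rank) * (B @ A)
    return base + (alpha / rank) * (lora_b @ lora_a)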
43 changes: 28 additions & 15 deletions recipes/lora_dpo_single_device.py
@@ -410,22 +410,35 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the full state dict with LoRA weights merged into base LLM weights
merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

# Construct the adapter weights
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}

merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)

ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
else:
# No need to merge state dict if we're only saving adapter weights
adapter_state_dict = {
k: v
for k, v in self._model.state_dict().items()
if adapter_key_filter(k)
}

ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

self._checkpointer.save_checkpoint(
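The ordering constraint called out in the comments above ("must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys") can be seen in a toy example; the key layout and import path here are assumptions based on the recipe code:

import torch
from torchtune.modules.peft import get_merged_lora_ckpt  # assumed import path

rank, alpha = 4, 8
state_dict = {
    "layer.weight": torch.randn(16, 16),
    "layer.lora_a.weight": torch.randn(rank, 16),
    "layer.lora_b.weight": torch.randn(16, rank),
}
# Capture the adapter weights *before* merging; the merge pops the lora_* keys.
adapter_sd = {k: v for k, v in state_dict.items() if "lora" in k}
merged_sd = get_merged_lora_ckpt(state_dict, rank=rank, alpha=alpha)
assert all("lora" not in k for k in merged_sd)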
13 changes: 7 additions & 6 deletions recipes/lora_finetune_distributed.py
@@ -628,12 +628,13 @@ def save_checkpoint(
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
if not self._save_adapter_weights_only:
Review comment (Contributor):
I think @pbontrager had looked at something with this previously, but we still create the full state dict on CPU in this case then, right? To actually reduce the memory footprint to just the adapter weights we'd need to change the call to get_full_model_state_dict above (btw it's ok to say "let's do this in a follow-up")

merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
60 changes: 36 additions & 24 deletions recipes/lora_finetune_single_device.py
@@ -486,14 +486,16 @@ def _setup_data(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
collate_fn=partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
)
),
)

@@ -527,24 +527,34 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}
if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}

merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)

# Construct the full state dict with LoRA weights merged into base LLM weights
merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
else:
# No need to merge state dict if we're only saving adapter weights
adapter_state_dict = {
k: v
for k, v in self._model.state_dict().items()
if adapter_key_filter(k)
}
Review comment (Contributor):
Doesn't this still materialize the entire state dict though?


ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
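The "avoid running upcast and H2D in state_dict post hook twice" comments refer to the cost of calling self._model.state_dict() more than once when a post hook does expensive work on every call; the recipes therefore snapshot once and filter the snapshot. A toy illustration with a counting hook standing in for the expensive one (uses PyTorch's private _register_state_dict_hook API; hypothetical, not the recipes' actual hook):

import torch.nn as nn

calls = {"n": 0}

def counting_post_hook(module, state_dict, prefix, local_metadata):
    # Stand-in for an expensive post hook (e.g. dtype upcast + device copy).
    calls["n"] += 1
    return state_dict

model = nn.Linear(4, 4)
model._register_state_dict_hook(counting_post_hook)

_ = model.state_dict()
_ = model.state_dict()
assert calls["n"] == 2  # two separate calls run the hook twice

calls["n"] = 0
sd = model.state_dict()  # snapshot once...
subset = {k: v for k, v in sd.items() if "weight" in k}  # ...filter the snapshot
assert calls["n"] == 1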
1 change: 1 addition & 0 deletions tests/assets/stack_exchange_paired_tiny.json

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions tests/recipes/test_lora_dpo_single_device.py
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Review comment (Contributor):
You're a hero

Reply (Collaborator, author):
I was just afraid someone would ask me to run the recipe manually, so this felt like less effort.

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import runpy
import sys
from pathlib import Path

import pytest
import torch
from omegaconf import OmegaConf
from tests.common import TUNE_PATH
from tests.recipes.utils import (
dummy_stack_exchange_dataset_config,
MODEL_TEST_CONFIGS,
write_hf_ckpt_config,
)
from tests.test_utils import (
CKPT_MODEL_PATHS,
gen_log_file_name,
get_loss_values_from_metric_logger,
)
from torchtune import config


class TestLoRADPOSingleDeviceRecipe:
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
return [
"batch_size=8",
"device=cpu",
f"dtype={dtype_str}",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
"seed=9",
f"epochs={epochs}",
"max_steps_per_epoch=2",
"optimizer.lr=2e-5",
"log_every_n_steps=1",
"gradient_accumulation_steps=1",
"clip_grad_norm=100",
"tokenizer.max_seq_len=512",
] + dummy_stack_exchange_dataset_config()

@pytest.mark.parametrize("save_adapter_weights_only", [False, True])
@pytest.mark.integration_test
def test_training_state_on_resume(
self, tmpdir, monkeypatch, save_adapter_weights_only
):
"""Test whether the recipe state is correctly updated on resume. Since this
is model agnostic, we should run this on the small model only. The test
consists of three stages:
- Train a model for 2 epochs
- Resume training after epoch 1
- Make sure final loss matches the expected value of a model successfully resumed from a ckpt
Unlike `tests.recipes.test_lora_finetune_single_device`, this test does not use pre-computed loss
values to benchmark against. This test just ensures the loss values are identical when resuming.
"""

ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
# Create a second copy for training resume
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for two epochs
cmd_1 = f"""
tune run lora_dpo_single_device \
--config llama2/7B_lora_dpo_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
metric_logger.filename={log_file} \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]

cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_1)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

expected_loss_values = get_loss_values_from_metric_logger(log_file)

resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)
# Resume training
cmd_2 = f"""
tune run lora_dpo_single_device \
--config llama2/7B_lora_dpo_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
metric_logger.filename={resumed_log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Second epoch only
resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)

torch.testing.assert_close(
resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
)

@pytest.mark.integration_test
def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

cmd = f"""
tune run lora_dpo_single_device \
--config llama2/7B_lora_dpo_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]

cmd = cmd + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# Next load both the merged weights in a Llama2 base model
# and the base model weights + trained adapter weights in the LoRA Llama 2 model
# The results of calling forward on dummy inputs should be the same.
inputs = torch.randint(low=0, high=32_000, size=(2, 100))

# Build LoRA model for loading base + adapter weights separately
lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model)

# Build base llama2 model for loading merged weights
base_llama2_config = MODEL_TEST_CONFIGS["llama2"]
llama2_model = config.instantiate(
OmegaConf.from_dotlist(base_llama2_config).model
)

# Load base model and trained adapter weights into LoRA model and call fwd
with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
lora_sd = torch.load(f, weights_only=True)
with open(ckpt_path, "rb") as f:
base_model_sd = torch.load(f, weights_only=True)
lora_model.load_state_dict(lora_sd, strict=False)
lora_model.load_state_dict(base_model_sd, strict=False)
baseline_out = lora_model(inputs)

# Load merged final ckpt directly into llama2 and call fwd
with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
sd = torch.load(f, weights_only=True)
llama2_model.load_state_dict(sd)
merged_ckpt_out = llama2_model(inputs)
torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
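The equivalence this test checks follows from linearity: W x + (alpha/r) * B (A x) equals (W + (alpha/r) * B A) x, so applying the adapter at runtime and merging ahead of time give the same outputs. A standalone numeric sketch (hypothetical shapes):

import torch

torch.manual_seed(0)
out_dim, in_dim, rank, alpha = 8, 16, 4, 8.0
w = torch.randn(out_dim, in_dim)
a = torch.randn(rank, in_dim)
b = torch.randn(out_dim, rank)
x = torch.randn(2, in_dim)

# Adapter applied at runtime vs. weights merged ahead of time:
runtime_out = x @ w.T + (alpha / rank) * ((x @ a.T) @ b.T)
merged_out = x @ (w + (alpha / rank) * (b @ a)).T
torch.testing.assert_close(runtime_out, merged_out)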
16 changes: 12 additions & 4 deletions tests/recipes/test_lora_finetune_distributed.py
@@ -97,14 +97,21 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch):
@pytest.mark.integration_test
@gpu_test(gpu_count=2)
@pytest.mark.parametrize(
"config, model_type, ckpt_type",
"config, model_type, ckpt_type, save_adapter_weights_only",
[
("llama2/7B_lora", "llama2", "hf"),
("llama3/8B_lora", "llama3", "tune"),
("llama2/7B_lora", "llama2", "hf", False),
("llama3/8B_lora", "llama3", "tune", False),
("llama2/7B_lora", "llama2", "hf", True),
],
)
def test_training_state_on_resume(
self, config, model_type, ckpt_type, tmpdir, monkeypatch
self,
config,
model_type,
ckpt_type,
tmpdir,
monkeypatch,
save_adapter_weights_only,
):
"""Test whether the recipe state is correctly updated on resume. Since this
is model agnostic, we should run this on the small model only. The test
@@ -139,6 +146,7 @@ def test_training_state_on_resume(
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]