Fix lora single device fine tune checkpoint saving & nan loss when use_dora=True #1909

Merged
merged 5 commits into from
Oct 31, 2024
4 changes: 3 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -120,7 +120,6 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
     """

     def __init__(self, cfg: DictConfig) -> None:
-
         self._device = utils.get_device(device=cfg.device)
         # Reduced precision logic
         self._dtype = training.get_dtype(cfg.dtype, device=self._device)
@@ -438,6 +437,9 @@ def _setup_model(
         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
         if self._is_dora:
+            for m in model.modules():
+                if hasattr(m, "initialize_dora_magnitude"):
+                    m.initialize_dora_magnitude()
             load_dora_magnitudes(model)
         if lora_weights_state_dict:
             lora_missing, lora_unexpected = model.load_state_dict(
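The ordering here matters: DoRA's magnitude vector is derived from the loaded base weight (a per-output-channel norm of W + BA, which equals the norm of W while lora_b is still zero-initialized), so it can only be computed after the base checkpoint has been applied and before the adapter state dict is loaded. A minimal self-contained sketch of that dependency, with module and attribute names chosen for illustration rather than taken from torchtune:

import torch
import torch.nn as nn

class ToyDoRALinear(nn.Module):
    """Stripped-down stand-in for a DoRA linear layer, for illustration only."""

    def __init__(self, in_dim: int, out_dim: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim), requires_grad=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # standard LoRA init: B starts at zero
        self.magnitude = nn.Parameter(torch.empty(out_dim))

    @torch.no_grad()
    def initialize_dora_magnitude(self):
        # Per-output-row norm of W + BA; meaningless until self.weight holds real values.
        merged = self.weight + self.lora_b.weight @ self.lora_a.weight
        self.magnitude.copy_(torch.linalg.norm(merged, dim=1))

Calling initialize_dora_magnitude before real weights have been copied into self.weight would bake uninitialized values into the magnitude, which is why the recipe only does it after the base state dict load above.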
8 changes: 6 additions & 2 deletions tests/recipes/test_lora_finetune_single_device.py
@@ -259,8 +259,9 @@ def test_training_state_on_resume(
             loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )

+    @pytest.mark.parametrize("use_dora", [False, True])
     @pytest.mark.integration_test
-    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
+    def test_save_and_load_merged_weights(self, tmpdir, monkeypatch, use_dora):
         ckpt = "llama2_tune"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
@@ -280,7 +281,10 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
            enable_activation_offloading=False \
         """.split()

-        model_config = MODEL_TEST_CONFIGS["llama2_lora"]
+        if use_dora:
+            model_config = MODEL_TEST_CONFIGS["llama2_dora"]
+        else:
+            model_config = MODEL_TEST_CONFIGS["llama2_lora"]

         cmd = cmd + self._get_test_config_overrides() + model_config
         monkeypatch.setattr(sys, "argv", cmd)
10 changes: 10 additions & 0 deletions tests/recipes/utils.py
@@ -135,6 +135,7 @@ def lora_llama2_test_config(
     lora_rank: int = 8,
     lora_alpha: float = 16,
     quantize_base: bool = False,
+    use_dora: bool = False,
 ) -> List[str]:
     return [
         # Note: we explicitly use _component_ so that we can also call
@@ -154,6 +155,7 @@ def lora_llama2_test_config(
         f"model.lora_alpha={lora_alpha}",
         "model.lora_dropout=0.0",
         f"model.quantize_base={quantize_base}",
+        f"model.use_dora={use_dora}",
     ]


@@ -207,6 +209,14 @@ def write_hf_ckpt_config(ckpt_dir: str):
         lora_rank=8,
         lora_alpha=16,
     ),
+    "llama2_dora": lora_llama2_test_config(
+        lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
+        apply_lora_to_mlp=False,
+        apply_lora_to_output=False,
+        lora_rank=8,
+        lora_alpha=16,
+        use_dora=True,
+    ),
     "llama2_qlora": lora_llama2_test_config(
         lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
         apply_lora_to_mlp=True,
35 changes: 20 additions & 15 deletions torchtune/models/convert_weights.py
@@ -6,7 +6,7 @@

 import re

-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import torch

@@ -252,23 +252,28 @@ def tune_to_peft_adapter_weights(
     num_heads: int = 32,
     num_kv_heads: int = 32,
     dim: int = 4096,
-    head_dim: int = None,
+    head_dim: Optional[int] = None,
 ):
     converted_state_dict = {}
     full_mapping = {}
-    # Rather than recreate a separate mapping for LoRA adapter weights, we just
-    # re-use the _FROM_HF mapping for base model weights. We iterate over it twice:
-    # once to add mappings for LoRA A matrices and once to add mappings for LoRA B matrices.
-    for k, v in _TO_PEFT_KEYS.items():
-        full_mapping.update(
-            {
-                vv.replace(".weight", f".{k}.weight"): kk.replace(
-                    ".weight", f".{v}.weight"
-                )
-                for kk, vv in _FROM_HF.items()
-                if vv is not None
-            }
-        )
+    # Rather than recreate a separate mapping for LoRA adapter weights, we re-use the
+    # _FROM_HF mapping for base model weights. The mapping is adapted to account for:
+    # LoRA A matrices, LoRA B matrices and the dora magnitude parameter.
+    for peft_key, peft_val in _TO_PEFT_KEYS.items():
+        for hf_key, hf_val in _FROM_HF.items():
+            if hf_val is None:
+                continue
+
+            if peft_key == "magnitude":
+                # e.g. attn.q_proj.magnitude -> attn.q_proj.lora_magnitude_vector
+                adapter_key = hf_val.replace(".weight", f".{peft_key}")
+                adapter_val = hf_key.replace(".weight", f".{peft_val}")
+            else:
+                # e.g. attn.q_proj.lora_a.weight -> attn.q_proj.lora_A.weight
+                adapter_key = hf_val.replace(".weight", f".{peft_key}.weight")
+                adapter_val = hf_key.replace(".weight", f".{peft_val}.weight")
+
+            full_mapping.update({adapter_key: adapter_val})

     if head_dim is None:
         head_dim = dim // num_heads
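To see what the reworked loop produces, here is a standalone sketch; the two dicts are trimmed, made-up stand-ins for torchtune's _TO_PEFT_KEYS and _FROM_HF tables, so the exact key strings are illustrative only. The point it demonstrates is that the DoRA magnitude entry drops the ".weight" suffix on both sides of the mapping, unlike the LoRA A/B entries.

# Trimmed stand-ins (assumed contents, not the real torchtune tables).
_TO_PEFT_KEYS = {
    "lora_a": "lora_A",
    "lora_b": "lora_B",
    "magnitude": "lora_magnitude_vector",
}
_FROM_HF = {
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
}

full_mapping = {}
for peft_key, peft_val in _TO_PEFT_KEYS.items():
    for hf_key, hf_val in _FROM_HF.items():
        if hf_val is None:
            continue
        if peft_key == "magnitude":
            # Magnitude is a bare vector, so no trailing ".weight" on either key.
            adapter_key = hf_val.replace(".weight", f".{peft_key}")
            adapter_val = hf_key.replace(".weight", f".{peft_val}")
        else:
            adapter_key = hf_val.replace(".weight", f".{peft_key}.weight")
            adapter_val = hf_key.replace(".weight", f".{peft_val}.weight")
        full_mapping[adapter_key] = adapter_val

for k, v in full_mapping.items():
    print(k, "->", v)
# layers.{}.attn.q_proj.lora_a.weight -> model.layers.{}.self_attn.q_proj.lora_A.weight
# layers.{}.attn.q_proj.lora_b.weight -> model.layers.{}.self_attn.q_proj.lora_B.weight
# layers.{}.attn.q_proj.magnitude -> model.layers.{}.self_attn.q_proj.lora_magnitude_vector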
3 changes: 2 additions & 1 deletion torchtune/modules/peft/dora.py
@@ -79,6 +79,7 @@ def initialize_parameters(self):
         _lora_a_init_params(self.lora_a)
         _lora_b_init_params(self.lora_b)

+    @torch.no_grad()
     def initialize_dora_magnitude(self):
         """
         DoRA initializes the magnitude vector such that its outputs are initially
@@ -87,7 +88,7 @@ def initialize_dora_magnitude(self):
         base_weight = self.weight.to(self.lora_a.weight.dtype)
         lora_weight = self.lora_b.weight @ self.lora_a.weight
         weight_norm = self._get_weight_norm(base_weight, lora_weight)
-        self.magnitude = nn.Parameter(weight_norm, requires_grad=True)
+        self.magnitude.copy_(weight_norm)

     def _create_weight_and_bias(self):
         """
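Why the in-place copy_ (plus the no_grad decorator) rather than rebinding self.magnitude to a fresh nn.Parameter: rebinding swaps out the parameter object that other components may already hold a reference to, whereas copy_ updates the registered parameter in place and keeps the norm computation out of autograd. A small standalone demonstration of that general behavior, using a toy module and values rather than torchtune code:

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
lin.magnitude = nn.Parameter(torch.ones(4))       # registered as a parameter
opt = torch.optim.SGD(lin.parameters(), lr=0.1)   # optimizer now references it

# Rebinding to a new nn.Parameter: the optimizer still points at the old object.
old = lin.magnitude
lin.magnitude = nn.Parameter(torch.full((4,), 2.0))
in_opt = any(p is lin.magnitude for g in opt.param_groups for p in g["params"])
print(in_opt)  # False

# In-place update under no_grad: the registered parameter object is preserved.
lin.magnitude = old
with torch.no_grad():
    lin.magnitude.copy_(torch.full((4,), 2.0))
in_opt = any(p is lin.magnitude for g in opt.param_groups for p in g["params"])
print(in_opt)  # True

The same concern applies to anything that captured the parameter before initialize_dora_magnitude runs.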