Fix save adapter weights only #1764

Merged (2 commits, Oct 8, 2024)
17 changes: 2 additions & 15 deletions recipes/lora_dpo_single_device.py
@@ -401,34 +401,21 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}

merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)

ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
else:
# No need to merge state dict if we're only saving adapter weights
adapter_state_dict = {
k: v.cpu() for k, v in get_adapter_params(self._model).items()
}

ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

self._checkpointer.save_checkpoint(
ckpt_dict,
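Read together with the PR title, the two added lines at the top of this hunk and the deletions below them move the adapter save out of the `save_adapter_weights_only` branch: the adapter state dict is now built once from `self.adapter_params` and stored under `training.ADAPTER_KEY` unconditionally, and the merged full model is only constructed when a full checkpoint is requested. A minimal, self-contained sketch of that control flow follows; the function signature and the stand-in constants are illustrative assumptions, not the recipe's actual API. The same shape of change is applied to recipes/lora_finetune_single_device.py below.

```python
# Hypothetical sketch of the consolidated save path implied by the diff above.
# Only the control flow mirrors the recipe; names marked "stand-in" are assumptions.
from typing import Any, Callable, Dict

import torch

ADAPTER_KEY = "adapter"  # stand-in for training.ADAPTER_KEY
MODEL_KEY = "model"      # stand-in for training.MODEL_KEY


def build_checkpoint_dict(
    model: torch.nn.Module,
    adapter_params: Dict[str, torch.Tensor],
    merge_fn: Callable[..., Dict[str, torch.Tensor]],  # stand-in for get_merged_lora_ckpt
    lora_rank: int,
    lora_alpha: float,
    save_adapter_weights_only: bool,
) -> Dict[str, Any]:
    ckpt_dict: Dict[str, Any] = {}

    # Adapter weights are collected unconditionally, before any branching,
    # so the saved adapter checkpoint is identical in both modes.
    adapter_state_dict = {k: v.cpu() for k, v in adapter_params.items()}
    ckpt_dict[ADAPTER_KEY] = adapter_state_dict

    if not save_adapter_weights_only:
        # Only the full checkpoint needs the model-wide state_dict walk
        # and the LoRA merge into the base weights.
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        ckpt_dict[MODEL_KEY] = merge_fn(state_dict, rank=lora_rank, alpha=lora_alpha)

    return ckpt_dict
```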
17 changes: 3 additions & 14 deletions recipes/lora_finetune_single_device.py
@@ -577,34 +577,23 @@ def save_checkpoint(self, epoch: int) -> None:
}
)

adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

if not self._save_adapter_weights_only:
# Construct the full state dict with LoRA weights merged into base LLM weights

# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

# Construct the adapter weights
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
k: v for k, v in state_dict.items() if adapter_key_filter(k)
}

merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)

ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
else:
# No need to merge state dict if we're only saving adapter weights
adapter_state_dict = {
k: v.cpu() for k, v in get_adapter_params(self._model).items()
}

ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
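The trailing context of this hunk shows the `adapter_config` dictionary that the finetune recipe also writes out, keyed with PEFT-style names (`r`, `lora_alpha`). For orientation, a filled-in example with made-up values might look like the following; only those two key names come from the diff, everything else is a hypothetical placeholder.

```python
# Hypothetical example; only the "r" and "lora_alpha" key names appear in the diff.
adapter_config = {
    "r": 8,            # LoRA rank, self._lora_rank in the recipe
    "lora_alpha": 16,  # LoRA scaling factor, self._lora_alpha in the recipe
    # the recipe may add further PEFT-compatible fields not visible in this hunk
}
```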
4 changes: 3 additions & 1 deletion tests/recipes/test_lora_dpo_single_device.py
@@ -32,7 +32,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
"batch_size=8",
"device=cpu",
f"dtype={dtype_str}",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
"seed=9",
f"epochs={epochs}",
@@ -83,6 +82,7 @@ def test_training_state_on_resume(
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -112,6 +112,7 @@ def test_training_state_on_resume(
metric_logger.filename={resumed_log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
enable_activation_checkpointing=True \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
@@ -142,6 +143,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
enable_activation_checkpointing=False \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
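Across all three test files, the change follows the same pattern: `enable_activation_checkpointing` is dropped from the shared `_get_test_config_overrides()` list and pinned explicitly in each test's command string, so individual tests control it (as also seen in the files below). A rough sketch of how those pieces compose into the recipe's `sys.argv` is shown here; the recipe name, config name, and paths are placeholders rather than values taken from the diff.

```python
# Hypothetical sketch of the override-composition pattern used by these tests.
# The recipe/config names and paths below are placeholders, not from the diff.
import sys


def _get_test_config_overrides(dtype_str: str = "fp32", epochs: int = 2):
    # Shared overrides: activation checkpointing is no longer set here, so
    # each test must pass it explicitly in its own command string.
    return [
        "batch_size=8",
        "device=cpu",
        f"dtype={dtype_str}",
        "seed=9",
        f"epochs={epochs}",
    ]


def build_argv(tmpdir: str, save_adapter_weights_only: bool):
    cmd = f"""
    tune run lora_dpo_single_device \
        --config my_lora_dpo_config \
        output_dir={tmpdir} \
        save_adapter_weights_only={save_adapter_weights_only} \
        enable_activation_checkpointing=True \
    """.split()
    return cmd + _get_test_config_overrides(epochs=3)


# The real tests install the resulting list with monkeypatch.setattr(sys, "argv", cmd)
# before invoking the recipe entry point.
argv = build_argv("/tmp/out", save_adapter_weights_only=True)
```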
5 changes: 4 additions & 1 deletion tests/recipes/test_lora_finetune_distributed.py
@@ -33,7 +33,6 @@ class TestLoRAFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
return [
"batch_size=4",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
"seed=9",
"epochs=2",
@@ -81,6 +80,7 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch):
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
reshard_after_forward={reshard_after_forward} \
enable_activation_checkpointing=False \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -147,6 +147,7 @@ def test_training_state_on_resume(
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
enable_activation_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
@@ -171,6 +172,7 @@ def test_training_state_on_resume(
tokenizer.prompt_template=null \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
""".split()

cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
@@ -213,6 +215,7 @@ def test_save_and_load_merged_weights(
checkpointer.model_type={model_type.upper()} \
tokenizer.path='{tokenizer_path}' \
tokenizer.prompt_template=null \
enable_activation_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
5 changes: 4 additions & 1 deletion tests/recipes/test_lora_finetune_single_device.py
@@ -35,7 +35,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
"batch_size=8",
"device=cpu",
f"dtype={dtype_str}",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
"seed=9",
f"epochs={epochs}",
@@ -133,6 +132,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch):
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
compile={compile} \
enable_activation_checkpointing=False \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_qlora"]
@@ -188,6 +188,7 @@ def test_training_state_on_resume(
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
enable_activation_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]
@@ -213,6 +214,7 @@ def test_training_state_on_resume(
metric_logger.filename={log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
enable_activation_checkpointing=True \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
@@ -244,6 +246,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
enable_activation_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]