FIX Correctly pass low_cpu_mem_usage argument (#2253)
There was a bug where, when creating a PEFT model with the task_type
argument set, the low_cpu_mem_usage argument was not passed along. This
is now fixed, and unit tests covering it were added.

This is a very niche bug because there is typically no need to pass
low_cpu_mem_usage=True when calling get_peft_model. Moreover, as the
option for this was only added recently (#2142) and is unreleased, few
if any users should be affected by the bug.
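
To illustrate the affected path, here is a minimal sketch of the call that was broken (the checkpoint name is illustrative, not part of this commit):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# Passing task_type routes get_peft_model through MODEL_TYPE_TO_PEFT_MODEL_MAPPING;
# before this fix, low_cpu_mem_usage was silently dropped on that path.
config = LoraConfig(target_modules="all-linear", task_type="CAUSAL_LM")
model = get_peft_model(base, config, low_cpu_mem_usage=True)

# With the fix, the freshly injected LoRA weights stay on the meta device.
lora_devices = {p.device for n, p in model.named_parameters() if "lora_" in n}
assert lora_devices == {torch.device("meta")}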
BenjaminBossan authored Dec 3, 2024
1 parent 3f9ce55 commit c057589
Showing 2 changed files with 27 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/peft/mapping.py
@@ -220,7 +220,11 @@ def get_peft_model(
     if peft_config.is_prompt_learning:
         peft_config = _prepare_prompt_learning_config(peft_config, model_config)
     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
-        model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
+        model,
+        peft_config,
+        adapter_name=adapter_name,
+        autocast_adapter_dtype=autocast_adapter_dtype,
+        low_cpu_mem_usage=low_cpu_mem_usage,
     )


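Since low_cpu_mem_usage=True leaves the adapter weights on the meta device, they must be materialized with real values before use. A rough sketch of that follow-up step, assuming the set_peft_model_state_dict flag added in #2142 and an illustrative weights file:

import torch
from peft.utils import set_peft_model_state_dict

# "model" is a PEFT model created with low_cpu_mem_usage=True;
# "adapter_weights.pt" is an illustrative file of trained adapter weights.
state_dict = torch.load("adapter_weights.pt", map_location="cpu")

# With low_cpu_mem_usage=True, the loaded tensors are assigned directly onto
# the meta parameters instead of first copying into pre-allocated buffers.
set_peft_model_state_dict(model, state_dict, low_cpu_mem_usage=True)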
22 changes: 22 additions & 0 deletions tests/test_initialization.py
@@ -1641,6 +1641,28 @@ def test_load_adapter_low_cpu_mem_usage_works(self, device, inputs, lora_path, lora_config):
         assert device_set_low_cpu_mem == device_set_not_low_cpu_mem
         assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
 
+    @pytest.mark.parametrize("device", devices)
+    def test_get_peft_model_low_cpu_mem_usage_works(self, device, inputs):
+        # when calling get_peft_model, the PEFT weights will not be initialized on device but remain on meta
+        model = self.get_model().to(device)
+        model = get_peft_model(model, LoraConfig(target_modules="all-linear"), low_cpu_mem_usage=True)
+
+        devices_lora_weights = {p.device for n, p in model.named_parameters() if "lora_" in n}
+        expected = {torch.device("meta")}
+        assert devices_lora_weights == expected
+
+    @pytest.mark.parametrize("device", devices)
+    def test_get_peft_model_with_task_type_low_cpu_mem_usage_works(self, device, inputs):
+        # same as the previous test, but pass the task_type argument
+        model = self.get_model().to(device)
+        model = get_peft_model(
+            model, LoraConfig(target_modules="all-linear", task_type="CAUSAL_LM"), low_cpu_mem_usage=True
+        )
+
+        devices_lora_weights = {p.device for n, p in model.named_parameters() if "lora_" in n}
+        expected = {torch.device("meta")}
+        assert devices_lora_weights == expected
+
     @pytest.mark.parametrize("device", devices)
     def test_inject_adapter_low_cpu_mem_usage_works(self, device, inputs, lora_path, lora_config):
         # external libs like transformers and diffusers use inject_adapter_in_model, let's check that this also works
