diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py
index 098bd0d4fd..e605110bd8 100644
--- a/src/peft/mixed_model.py
+++ b/src/peft/mixed_model.py
@@ -266,7 +266,7 @@ def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> Non
             self.modules_to_save = set(modules_to_save)
         else:
             self.modules_to_save.update(modules_to_save)
-        _set_trainable(self, adapter_name)
+        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
 
     def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
         """
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 66eba17b3b..548b7a560b 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -967,7 +967,8 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
                 self.modules_to_save = set(peft_config.modules_to_save)
             else:
                 self.modules_to_save.update(peft_config.modules_to_save)
-            _set_trainable(self, adapter_name)  # this may add a new ModulesToSaveWrapper
+            # this may add a new ModulesToSaveWrapper
+            _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
 
     def get_layer_status(self) -> list[TunerLayerStatus]:
         """Get the status of each adapter layer in the model.
@@ -1457,7 +1458,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name)
+        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """
@@ -2190,7 +2191,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name)
+        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """
@@ -2411,7 +2412,7 @@ def __init__(
                 break
 
         # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
-        _set_trainable(self, adapter_name)
+        _set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
         """
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 8d0b7f7f06..43974c72ae 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -273,8 +273,10 @@ def update(self, adapter_name):
 
                 context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
                 break
-        with context_manager:
-            self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
+
+        if adapter_name not in self.modules_to_save:
+            with context_manager:
+                self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
 
         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
             old_hook = self.modules_to_save[adapter_name]._hf_hook
@@ -414,10 +416,10 @@ def _freeze_adapter(model, adapter_name):
             p.requires_grad = False
 
 
-def _set_trainable(model, adapter_name):
+def _set_trainable(model, adapter_name, modules_to_save):
     key_list = [key for key, _ in model.named_modules()]
     for key in key_list:
-        target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
+        target_module_found = any(key.endswith(target_key) for target_key in modules_to_save)
         if target_module_found:
             parent, target, target_name = _get_submodules(model, key)
             if isinstance(target, ModulesToSaveWrapper):
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 611b07bf97..654e8c2af4 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -1529,6 +1529,26 @@ def test_multiple_adapters_seq_cls_mixed_modules_to_save_merging_adapters(self):
         with pytest.raises(ValueError, match=msg):
             model.add_weighted_adapter(["default", "other"], weights=[1.0, 1.0], adapter_name="merged")
 
+    def test_multiple_adapters_no_needless_copy_modules_to_save(self):
+        # See 2206
+        # The problem was that we keep a "global" modules_to_save on the model which contains all possible
+        # modules_to_save for each adapter. When the first adapter targets embed_tokens with modules_to_save and the
+        # second adapter targets lm_head, then embed_tokens will create a copy of the original module for the second
+        # adapter, even though it's not needed. The copy still acts as expected but uses unnecessary memory.
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
+        config0 = LoraConfig(modules_to_save=["embed_tokens"])
+        config1 = LoraConfig(modules_to_save=["lm_head"])
+        model = get_peft_model(model, config0)
+        model.add_adapter("other", config1)
+
+        lm_head_keys = list(model.base_model.model.lm_head.modules_to_save.keys())
+        assert lm_head_keys == ["other"]
+
+        embed_token_keys = list(model.base_model.model.model.decoder.embed_tokens.modules_to_save.keys())
+        # before the fix, this would be: ['default', 'other']
+        assert embed_token_keys == ["default"]
+
     def test_existing_model_card(self):
         # ensure that if there is already a model card, it is not overwritten
         model = MLP()