diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py
index 4f5b119f34..5777581e9b 100644
--- a/src/peft/tuners/adalora/layer.py
+++ b/src/peft/tuners/adalora/layer.py
@@ -26,7 +26,8 @@ class AdaLoraLayer(LoraLayer):
     # List all names of layers that may contain adapter weights
     # Note: ranknum doesn't need to be included as it is not an nn.Module
-    adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]
+    adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
+    # other_param_names is defined in LoraLayer

     def __init__(
         self,
diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py
index b4ff69cc64..cd278a450a 100644
--- a/src/peft/tuners/ia3/layer.py
+++ b/src/peft/tuners/ia3/layer.py
@@ -25,8 +25,10 @@ class IA3Layer(BaseTunerLayer):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = ["ia3_l"]
+    # All names of layers that may contain adapter weights
+    adapter_layer_names = ("ia3_l",)
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names = ("scaling",)

     def __init__(
         self,
diff --git a/src/peft/tuners/loha/layer.py b/src/peft/tuners/loha/layer.py
index 26f57ac681..2a8a205b02 100644
--- a/src/peft/tuners/loha/layer.py
+++ b/src/peft/tuners/loha/layer.py
@@ -24,8 +24,9 @@ class LoHaLayer(LycorisLayer, nn.Module):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2"]
+    # All names of layers that may contain adapter weights
+    adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2")
+    # other_param_names is defined on parent class

     def __init__(self):
         LycorisLayer.__init__(self)
diff --git a/src/peft/tuners/lokr/layer.py b/src/peft/tuners/lokr/layer.py
index 9b01ecf96f..97f3afb6fd 100644
--- a/src/peft/tuners/lokr/layer.py
+++ b/src/peft/tuners/lokr/layer.py
@@ -24,8 +24,8 @@ class LoKrLayer(LycorisLayer, nn.Module):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = [
+    # All names of layers that may contain adapter weights
+    adapter_layer_names = (
         "lokr_w1",
         "lokr_w1_a",
         "lokr_w1_b",
@@ -33,7 +33,8 @@ class LoKrLayer(LycorisLayer, nn.Module):
         "lokr_w2_a",
         "lokr_w2_b",
         "lokr_t2",
-    ]
+    )
+    # other_param_names is defined on parent class

     def __init__(self):
         LycorisLayer.__init__(self)
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 0eb2efa2f2..ab9eb83fcc 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -26,8 +26,10 @@ class LoraLayer(BaseTunerLayer):
-    # List all names of layers that may contain adapter weights
-    adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]
+    # All names of layers that may contain (trainable) adapter weights
+    adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")

     def __init__(self, in_features: int, out_features: int, **kwargs):
         self.r = {}
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 666c611c0e..1fe5322c3f 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -661,29 +661,16 @@ def delete_adapter(self, adapter_name: str):
         del self.peft_config[adapter_name]
         key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
+        new_adapter = None
         for key in key_list:
             _, target, _ = _get_submodules(self.model, key)
             if isinstance(target, LoraLayer):
-                for attr in [
-                    "r",
-                    "lora_alpha",
-                    "scaling",
-                    "lora_A",
-                    "lora_B",
-                    "lora_embedding_A",
-                    "lora_embedding_B",
-                    "lora_dropout",
-                ]:
-                    if adapter_name in getattr(target, attr):
-                        getattr(target, attr).pop(adapter_name)
-                if adapter_name in target.active_adapters:
-                    resetting_active_adapter = (
-                        list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
-                    )
-                    warnings.warn(
-                        f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
-                    )
-                    target.set_adapter(resetting_active_adapter)
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        if new_adapter:
+            self.active_adapter = new_adapter

     def merge_and_unload(self, progressbar: bool = False, safe_merge: bool = False):
         r"""
diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py
index 8d3fb7481b..3f7f9ac02e 100644
--- a/src/peft/tuners/lycoris_utils.py
+++ b/src/peft/tuners/lycoris_utils.py
@@ -62,6 +62,8 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
     r"""
     A base layer for LyCORIS like adapters
     """
+    # adapter_layer_names needs to be defined on the child class
+    other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout")

     def __init__(self):
         self.r = {}
@@ -391,17 +393,13 @@ def delete_adapter(self, adapter_name: str):
         del self.peft_config[adapter_name]
         key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
         for key in key_list:
             _, target, _ = _get_submodules(self.model, key)
             if isinstance(target, LycorisLayer):
-                for attr in target.adapter_layer_names:
-                    if adapter_name in getattr(target, attr):
-                        getattr(target, attr).pop(adapter_name)
-                if adapter_name in target.active_adapters:
-                    resetting_active_adapter = (
-                        list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
-                    )
-                    warnings.warn(
-                        f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
-                    )
-                    target.set_adapter(resetting_active_adapter)
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        if new_adapter:
+            self.active_adapter = new_adapter
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 8ca3abfa56..c449ec2d47 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -16,6 +16,7 @@
 import logging
 import re
+import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Union

@@ -272,8 +273,10 @@ class BaseTunerLayer(ABC):
     """
     active_adapter = None

-    # List all names of layers that may contain adapter weights
-    adapter_layer_names: list[str] = []
+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str] = ()
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names: tuple[str] = ()

     # indicates whether all adapters should be disabled
     _disable_adapters: bool = False
@@ -351,6 +354,54 @@ def set_adapter(self, adapter_names: str | list[str]):
         self._active_adapter = adapter_names

+    def _all_available_adapter_names(self) -> list[str]:
+        """Return a sorted list of all available adapter names"""
+        adapter_names = set()
+        for name in self.adapter_layer_names + self.other_param_names:
+            # we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter
+            # names
+            attr = getattr(self, name)
+            if hasattr(attr, "keys"):
+                adapter_names.update(attr.keys())
+        return sorted(adapter_names)
+
+    def delete_adapter(self, adapter_name: str) -> None:
+        """
+        Delete an adapter from the layer
+
+        This should be called on all adapter layers, or else we will get an inconsistent state.
+
+        This method will also set a new active adapter if the deleted adapter was an active adapter. It is important that
+        the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers.
+
+        Args:
+            adapter_name (`str`): The name of the adapter to delete
+
+        """
+        for attr in self.adapter_layer_names + self.other_param_names:
+            if adapter_name in getattr(self, attr):
+                del getattr(self, attr)[adapter_name]
+
+        if adapter_name in self.active_adapters:
+            # choose a new active adapter
+            active_adapters = self.active_adapters[:]
+            active_adapters.remove(adapter_name)
+            if active_adapters:
+                self.set_adapter(active_adapters)
+            else:
+                # no active adapters left, set a new default adapter
+                # here we get the list of all existing adapter names and choose the first one
+                remaining_adapters = self._all_available_adapter_names()
+                if not remaining_adapters:
+                    raise ValueError("You tried to delete the only adapter in the model, this is not possible.")
+
+                new_active_adapter = remaining_adapters[0]
+                warnings.warn(
+                    f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to "
+                    f"{new_active_adapter}."
+                )
+                self.set_adapter(remaining_adapters[0])
+

 def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
     """A helper method to check if the passed module's key name matches any of the target modules in the
     adapter_config.
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 9bd4dec9b6..368b8ec9aa 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -681,6 +681,14 @@ def run_with_disable(config_kwargs, bias):
             # This is bad, there was a warning about the bias when there should not have been any.
             self.fail("There should be no warning when bias is set to 'none'")

+    @parameterized.expand(TEST_CASES)
+    def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_delete_adapter(model_id, config_cls, config_kwargs)
+
+    @parameterized.expand(TEST_CASES)
+    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(TEST_CASES)
     def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
         self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index ea30a8183c..bb8df694d7 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -154,6 +154,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
     def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
         self._test_delete_adapter(model_id, config_cls, config_kwargs)

+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
         self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index cf200399bf..25931d5b56 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -125,6 +125,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
     def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
         self._test_delete_adapter(model_id, config_cls, config_kwargs)

+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
         self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index 94e2c81835..ce09fc6247 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -146,6 +146,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
     def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
         self._test_delete_adapter(model_id, config_cls, config_kwargs)

+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(
         PeftTestConfigManager.get_grid_parameters(
             {
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 8bb7a104cd..1322eaa76b 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -29,6 +29,7 @@
     IA3Config,
     LoraConfig,
     PeftModel,
+    PeftType,
     PrefixTuningConfig,
     PromptEncoderConfig,
     PromptLearningConfig,
@@ -815,42 +816,69 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar
             self.assertIsNotNone(param.grad)

     def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
-        if issubclass(config_cls, AdaLoraConfig):
-            # AdaLora does not support adding more than 1 adapter
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
+        # IA3 does not support deleting adapters yet, but it just needs to be added
+        # AdaLora does not support multiple adapters
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        if config.peft_type not in supported_peft_types:
             return

         model = self.transformers_class.from_pretrained(model_id)
+        if isinstance(config.target_modules, str):
+            # TODO this should be doable
+            self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")
+
+        adapter_to_delete = "delete_me"
+        model = get_peft_model(model, config)
+        model.add_adapter(adapter_to_delete, config)
+        model.set_adapter(adapter_to_delete)
+        model = model.to(self.torch_device)
+        model.delete_adapter(adapter_to_delete)
+        self.assertFalse(adapter_to_delete in model.peft_config)
+        self.assertEqual(model.active_adapters, ["default"])
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            _, target, _ = _get_submodules(model, key)
+            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
+            for attr in attributes_to_check:
+                self.assertFalse(adapter_to_delete in getattr(target, attr))
+
+    def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
+        # same as test_delete_adapter, but this time an inactive adapter is deleted
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
+        # IA3 does not support deleting adapters yet, but it just needs to be added
+        # AdaLora does not support multiple adapters
         config = config_cls(
             base_model_name_or_path=model_id,
             **config_kwargs,
         )
+        if config.peft_type not in supported_peft_types:
+            return
+
+        model = self.transformers_class.from_pretrained(model_id)
+        if isinstance(config.target_modules, str):
+            # TODO this should be doable
+            self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")
+
         adapter_to_delete = "delete_me"
         model = get_peft_model(model, config)
         model.add_adapter(adapter_to_delete, config)
-        model.set_adapter(adapter_to_delete)
+        # "delete_me" is added but not activated
         model = model.to(self.torch_device)
+        model.delete_adapter(adapter_to_delete)
+        self.assertFalse(adapter_to_delete in model.peft_config)
+        self.assertEqual(model.active_adapters, ["default"])

-        if config.peft_type not in ("LORA"):
-            with self.assertRaises(AttributeError):
-                model.delete_adapter(adapter_to_delete)
-        else:
-            model.delete_adapter(adapter_to_delete)
-            self.assertFalse(adapter_to_delete in model.peft_config)
-            key_list = [key for key, _ in model.named_modules() if "lora" not in key]
-            for key in key_list:
-                _, target, _ = _get_submodules(model, key)
-                if isinstance(target, LoraLayer):
-                    for attr in [
-                        "r",
-                        "lora_alpha",
-                        "scaling",
-                        "lora_A",
-                        "lora_B",
-                        "lora_embedding_A",
-                        "lora_embedding_B",
-                        "lora_dropout",
-                    ]:
-                        self.assertFalse(adapter_to_delete in getattr(target, attr))
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            _, target, _ = _get_submodules(model, key)
+            attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
+            for attr in attributes_to_check:
+                self.assertFalse(adapter_to_delete in getattr(target, attr))

     def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
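
A minimal usage sketch of the behavior introduced above, assuming a LoRA-compatible checkpoint; the model id "gpt2", the target module "c_attn", and the hyperparameters are illustrative choices only:

# Sketch of the refactored delete_adapter flow (assumed model id and target module).
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])

model = get_peft_model(base, config)    # creates the "default" adapter
model.add_adapter("delete_me", config)  # register a second adapter
model.set_adapter("delete_me")          # make it the active one

# Deleting the active adapter now removes its entries from adapter_layer_names and
# other_param_names on every tuner layer, then deterministically falls back to a
# remaining adapter (with a warning) instead of leaving the layers inconsistent.
model.delete_adapter("delete_me")
print(model.active_adapters)            # ["default"], as asserted in the new tests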