Add add_weighted_adapter to IA3 adapters #1701

Merged
merged 7 commits into from May 17, 2024
Changes from 4 commits
11 changes: 11 additions & 0 deletions docs/source/developer_guides/model_merging.md
@@ -138,3 +138,14 @@ print(tokenizer.decode(outputs[0]))

</hfoption>
</hfoptions>


## Merging $(IA)^3$ Models
$(IA)^3$ models support linear model merging. To merge $(IA)^3$ adapters, use the `~IA3Model.add_weighted_adapter` method. It works like `~LoraModel.add_weighted_adapter` for LoRA, except that it does not accept the `combination_type` parameter. Assuming we have a PEFT model with three $(IA)^3$ adapters loaded, we can merge them as follows:

```py
adapters = ["adapter1", "adapter2", "adapter3"]
weights = [0.4, 0.3, 0.3]
adapter_name = "merge"
model.add_weighted_adapter(adapters, weights, adapter_name)
```
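
The merged adapter then behaves like any other adapter. A minimal follow-up sketch, assuming `model` and `tokenizer` are the same objects used in the earlier examples of this guide:

```py
model.set_adapter("merge")

inputs = tokenizer("What is PEFT?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0]))
```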
107 changes: 101 additions & 6 deletions src/peft/tuners/ia3/model.py
@@ -15,7 +15,7 @@

import re
import warnings
from dataclasses import asdict
from dataclasses import asdict, replace
from enum import Enum
from typing import Optional

@@ -29,6 +29,7 @@
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_freeze_adapter,
_get_submodules,
)

@@ -279,17 +280,20 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
module.set_adapter(adapter_name)
self.active_adapter = adapter_name

def _prepare_adapter_config(self, peft_config, model_config):
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
peft_config.target_modules = set(
Contributor (author) comment: As discussed in #980 (comment), this seemed to be a bug.
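
One plausible reading of the fix, offered only as an assumption since the linked #980 discussion is not shown on this page: assigning the shared mapping value to the config directly aliases a module-level object, whereas `set(...)` makes a per-config copy. A small standalone sketch of that aliasing behaviour:

```py
# Hypothetical illustration (assumption about the referenced bug): the
# TRANSFORMERS_MODELS_TO_IA3_* mapping values are shared, module-level
# collections, so assigning one directly to a config aliases the shared object.
SHARED_MAPPING = {"t5": ["k", "v", "wo"]}

aliased = SHARED_MAPPING["t5"]      # old behaviour: same object as the mapping value
aliased.append("q")                 # mutates the global mapping as a side effect
print(SHARED_MAPPING["t5"])         # ['k', 'v', 'wo', 'q']

copied = set(SHARED_MAPPING["t5"])  # new behaviour: independent per-config set
copied.add("o")                     # the shared mapping is left untouched
print(SHARED_MAPPING["t5"])         # ['k', 'v', 'wo', 'q']
```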

TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
if peft_config.feedforward_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
raise ValueError("Please specify `feedforward_modules` in `peft_config`")
peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
model_config["model_type"]
]
peft_config.feedforward_modules = set(
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config

def _unload_and_optionally_merge(
@@ -393,3 +397,94 @@ def delete_adapter(self, adapter_name: str) -> None:
new_adapter = target.active_adapters[:]

self.active_adapter = new_adapter or []

def _check_add_weighted_adapter(self, adapters: list[str]) -> tuple[str, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
# Validate existence of adapters
for adapter in adapters:
if adapter not in self.peft_config:
raise ValueError(f"Adapter {adapter} does not exist")

# Check for conflicting modules_to_save
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
if any(
sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 for wrapper in modules_to_save_wrappers
):
raise ValueError("Cannot add weighted adapters targeting the same module with modules_to_save.")

# Ensure all adapters have compatible target and feedforward module types
target_module_types = {type(self.peft_config[adapter].target_modules) for adapter in adapters}
feedforward_module_types = {type(self.peft_config[adapter].feedforward_modules) for adapter in adapters}
if len(target_module_types) > 1 or len(feedforward_module_types) > 1:
raise ValueError("All adapter configs should have the same type for target and feedforward modules.")

# Combine target and feedforward modules
if str in target_module_types:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
else:
new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))
Reviewer (member) comment: Oh yeah, simpler than the reduce op we have in LoRA 👍


if str in feedforward_module_types:
new_feedforward_modules = "|".join(
f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
)
else:
new_feedforward_modules = set.union(
*(self.peft_config[adapter].feedforward_modules for adapter in adapters)
)

return new_target_modules, new_feedforward_modules

def add_weighted_adapter(
self,
adapters: list[str],
weights: list[float],
adapter_name: str,
) -> None:
"""
This method adds a new adapter by merging the given adapters with the given weights.

Args:
adapters (`list`):
List of adapter names to be merged.
weights (`list`):
List of weights for each adapter.
adapter_name (`str`):
Name of the new adapter.
"""
if adapter_name in list(self.peft_config.keys()):
return

new_target_modules, new_feedforward_modules = self._check_add_weighted_adapter(
adapters=adapters,
)

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
target_modules=new_target_modules,
feedforward_modules=new_feedforward_modules,
)
self.inject_adapter(self.model, adapter_name)

# Do we really need that?
_freeze_adapter(self.model, adapter_name)

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
if adapter_name in target.ia3_l:
target_ia3_l = target.ia3_l[adapter_name]
else:
continue

target_ia3_l.data = target_ia3_l.data.zero_()
for adapter, weight in zip(adapters, weights):
if adapter in target.ia3_l:
current_adapter_ia3_l = target.ia3_l[adapter]
else:
continue
target_ia3_l.data += current_adapter_ia3_l.data * weight
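
For reference, the merge above is a plain per-module linear combination of the $(IA)^3$ scaling vectors. A small standalone sketch of that arithmetic, using hypothetical tensors in place of the `target.ia3_l[adapter]` parameters:

```py
import torch

# Hypothetical stand-ins for target.ia3_l[adapter] scaling vectors.
ia3_l = {
    "adapter1": torch.tensor([1.0, 1.0, 1.0]),
    "adapter2": torch.tensor([0.5, 2.0, 1.0]),
    "adapter3": torch.tensor([2.0, 0.0, 1.0]),
}
weights = {"adapter1": 0.4, "adapter2": 0.3, "adapter3": 0.3}

# Mirrors the loop above: zero the new adapter's vector, then accumulate weighted copies.
merged = torch.zeros_like(ia3_l["adapter1"])
for name, weight in weights.items():
    merged += ia3_l[name] * weight

print(merged)  # tensor([1.1500, 1.0000, 1.0000])
```
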
3 changes: 0 additions & 3 deletions src/peft/tuners/lora/model.py
@@ -575,9 +575,6 @@ def add_weighted_adapter(

if adapter_name in list(self.peft_config.keys()):
return
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")
Contributor (author) comment on lines -578 to -580: This is done in `_check_add_weighted_adapter`:

for adapter in adapters:
    if adapter not in list(self.peft_config.keys()):
        raise ValueError(f"Adapter {adapter} does not exist")

which is called right after:


combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
1 change: 1 addition & 0 deletions tests/test_decoder_models.py
@@ -296,6 +296,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "CAUSAL_LM",
},
1 change: 1 addition & 0 deletions tests/test_encoder_decoder_models.py
@@ -184,6 +184,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
)
1 change: 1 addition & 0 deletions tests/test_feature_extraction_models.py
@@ -174,6 +174,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "FEATURE_EXTRACTION",
},
80 changes: 63 additions & 17 deletions tests/testing_common.py
Expand Up @@ -1139,22 +1139,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
assert not torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10)
assert torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4)

def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return pytest.skip(f"Test not applicable for {config_cls}")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if not isinstance(config, LoraConfig):
return pytest.skip(f"Test not applicable for {config}")

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])
def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_list, weight_list):
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], replace(config, r=20))
model = model.to(self.torch_device)
@@ -1213,7 +1198,11 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw

# test ties re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[:2], weight_list[:2], "multi_adapter_ties_reweighting", combination_type="ties", density=0.5
adapter_list[:2],
weight_list[:2],
"multi_adapter_ties_reweighting",
combination_type="ties",
density=0.5,
)

# test dare_linear re-weighting with multiple adapters
@@ -1338,6 +1327,63 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
assert model.active_adapters == [adapter_name]
model(**dummy_input)[0]

def _test_weighted_combination_of_adapters_ia3(self, model, config, adapter_list, weight_list):
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], config)
model = model.to(self.torch_device)

# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

# test re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_reweighting")

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_reweighting",
]
for new_adapter in new_adapters:
assert new_adapter in model.peft_config

dummy_input = self.prepare_inputs_for_testing()
model.eval()
for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
assert model.active_adapter == adapter_name
assert model.active_adapters == [adapter_name]
model(**dummy_input)[0]

def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return pytest.skip(f"Test not applicable for {config_cls}")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
# Initialize the config
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)

# Define a dictionary to map config types to their respective test functions
Reviewer (member) comment: I think it's more straightforward to read (though less elegant) if we do:

if isinstance(config, LoraConfig):
    self._test_weighted_combination_of_adapters_lora(...)
elif ...

WDYT?

Contributor (author) reply: I initially did that, but if we want to call model = get_peft_model(model, config, adapter_list[0]) only for LoraConfig and IA3Config, we'd need to:

# check if config is LoraConfig or IA3Config. Skip if not.
if not isinstance(config, (LoraConfig, IA3Config)):
    pytest.skip(f"Test not applicable for {config}")

# Initialize model ...

# Call test method according to config type
if isinstance(config, LoraConfig):
    self._test_weighted_combination_of_adapters_lora(...)
elif isinstance(config, IA3Config):
    self._test_weighted_combination_of_adapters_ia3(...)
else:
    pytest.skip(f"Test not applicable for {config}")

and it seems a bit redundant.

All in all, I don't have a strong opinion here. Let me know if you think this is more readable and I'll change it!

Reviewer (member) reply: I see what you mean but honestly, for a test I don't care so much about such issues; it's more important to me that I can quickly read from top to bottom and understand what's going on.

test_functions = {
LoraConfig: self._test_weighted_combination_of_adapters_lora,
IA3Config: self._test_weighted_combination_of_adapters_ia3,
}

# Get the test function based on the config type
test_function = test_functions.get(type(config))

if test_function:
# Only instantiate the model if a valid config is provided
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])
test_function(model, config, adapter_list, weight_list)
else:
pytest.skip(f"Test not applicable for {config}")

def _test_disable_adapter(self, model_id, config_cls, config_kwargs):
task_type = config_kwargs.get("task_type")
if (task_type == "SEQ_2_SEQ_LM") and (config_cls in (PromptTuningConfig, PromptEncoderConfig)):