Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add add_weighted_adapter to IA3 adapters #1701

Merged
merged 7 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import re
import warnings
from dataclasses import asdict
from dataclasses import asdict, replace
from enum import Enum
from typing import Optional

Expand All @@ -29,6 +29,7 @@
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_freeze_adapter,
_get_submodules,
)

Expand Down Expand Up @@ -279,17 +280,20 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
module.set_adapter(adapter_name)
self.active_adapter = adapter_name

def _prepare_adapter_config(self, peft_config, model_config):
@staticmethod
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
peft_config.target_modules = set(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in #980 (comment), this seemed to be a bug.

TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
if peft_config.feedforward_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
raise ValueError("Please specify `feedforward_modules` in `peft_config`")
peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
model_config["model_type"]
]
peft_config.feedforward_modules = set(
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config

def _unload_and_optionally_merge(
Expand Down Expand Up @@ -393,3 +397,94 @@ def delete_adapter(self, adapter_name: str) -> None:
new_adapter = target.active_adapters[:]

self.active_adapter = new_adapter or []

def _check_add_weighted_adapter(self, adapters: list[str]) -> tuple[str, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
# Validate existence of adapters
for adapter in adapters:
if adapter not in self.peft_config:
raise ValueError(f"Adapter {adapter} does not exist")

# Check for conflicting modules_to_save
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
if any(
sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 for wrapper in modules_to_save_wrappers
):
raise ValueError("Cannot add weighted adapters targeting the same module with modules_to_save.")

# Ensure all adapters have compatible target and feedforward module types
target_module_types = {type(self.peft_config[adapter].target_modules) for adapter in adapters}
feedforward_module_types = {type(self.peft_config[adapter].feedforward_modules) for adapter in adapters}
if len(target_module_types) > 1 or len(feedforward_module_types) > 1:
raise ValueError("All adapter configs should have the same type for target and feedforward modules.")

# Combine target and feedforward modules
if str in target_module_types:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
else:
new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, simpler than the reduce op we have in LoRA 👍


if str in feedforward_module_types:
new_feedforward_modules = "|".join(
f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
)
else:
new_feedforward_modules = set.union(
*(self.peft_config[adapter].feedforward_modules for adapter in adapters)
)

return new_target_modules, new_feedforward_modules

def add_weighted_adapter(
    self,
    adapters: list[str],
    weights: list[float],
    adapter_name: str,
) -> None:
    """
    This method adds a new adapter by merging the given adapters with the given weights.

    Args:
        adapters (`list`):
            List of adapter names to be merged.
        weights (`list`):
            List of weights for each adapter; must have the same length as `adapters`.
        adapter_name (`str`):
            Name of the new adapter.

    Raises:
        ValueError: If `adapters` and `weights` have different lengths, or if the adapters are
            missing or incompatible (see `_check_add_weighted_adapter`).
    """
    if adapter_name in list(self.peft_config.keys()):
        # Adapter of that name already exists; nothing to do.
        return

    # zip() below would silently drop trailing entries on a length mismatch; fail loudly instead.
    if len(adapters) != len(weights):
        raise ValueError(f"Length mismatch: got {len(adapters)} adapters but {len(weights)} weights.")

    new_target_modules, new_feedforward_modules = self._check_add_weighted_adapter(
        adapters=adapters,
    )

    # The merged adapter reuses the first adapter's config, with the combined module specs.
    self.peft_config[adapter_name] = replace(
        self.peft_config[adapters[0]],
        target_modules=new_target_modules,
        feedforward_modules=new_feedforward_modules,
    )
    self.inject_adapter(self.model, adapter_name)

    # The merged adapter is a fixed combination of existing adapters and is not meant to be
    # trained further, so freeze its parameters.
    _freeze_adapter(self.model, adapter_name)

    key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
    for key in key_list:
        _, target, _ = _get_submodules(self.model, key)
        if isinstance(target, IA3Layer):
            if adapter_name in target.ia3_l:
                target_ia3_l = target.ia3_l[adapter_name]
            else:
                continue

            # Reset in place, then accumulate the weighted sum of the source adapters'
            # scaling vectors; adapters that do not touch this layer are skipped.
            target_ia3_l.data.zero_()
            for adapter, weight in zip(adapters, weights):
                if adapter in target.ia3_l:
                    current_adapter_ia3_l = target.ia3_l[adapter]
                else:
                    continue
                target_ia3_l.data += current_adapter_ia3_l.data * weight
3 changes: 0 additions & 3 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,6 @@ def add_weighted_adapter(

if adapter_name in list(self.peft_config.keys()):
return
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")
Comment on lines -578 to -580
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done in _check_add_weighted_adapter

for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")
, which is called right after


combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
Expand Down
1 change: 1 addition & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "CAUSAL_LM",
},
Expand Down
1 change: 1 addition & 0 deletions tests/test_encoder_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_feature_extraction_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "FEATURE_EXTRACTION",
},
Expand Down
Loading