Skip to content

Commit

Permalink
Refactor to simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrs committed May 1, 2024
1 parent 36d7f28 commit 3dd8e12
Showing 1 changed file with 17 additions and 47 deletions.
64 changes: 17 additions & 47 deletions src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,67 +400,37 @@ def delete_adapter(self, adapter_name: str) -> None:

self.active_adapter = new_adapter or []

def _check_add_weighted_adapter(self, adapters: list[str]) -> str:
def _check_add_weighted_adapter(self, adapters: list[str]) -> tuple[str, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
# Validate existence of adapters
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
if adapter not in self.peft_config:
raise ValueError(f"Adapter {adapter} does not exist")

# If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
# modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
# have modules for the adapters to be merged.
# Check for conflicting modules_to_save
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
problematic_wrappers = [
wrapper
for wrapper in modules_to_save_wrappers
if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1
]
if problematic_wrappers:
raise ValueError(
"Cannot add weighted adapters if they target the same module with modules_to_save, but found "
f"{len(problematic_wrappers)} such instance(s)."
)
if any(sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 for wrapper in modules_to_save_wrappers):
raise ValueError("Cannot add weighted adapters targeting the same module with modules_to_save.")

target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
if not target_module_types:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(target_module_types)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
)
# Ensure all adapters have compatible target and feedforward module types
target_module_types = {type(self.peft_config[adapter].target_modules) for adapter in adapters}
feedforward_module_types = {type(self.peft_config[adapter].feedforward_modules) for adapter in adapters}
if len(target_module_types) > 1 or len(feedforward_module_types) > 1:
raise ValueError("All adapter configs should have the same type for target and feedforward modules.")

if target_module_types[0] == str:
# Combine target and feedforward modules
if str in target_module_types:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
elif target_module_types[0] == set:
new_target_modules = reduce(
operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
)
else:
raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))

feedforward_modules = [type(self.peft_config[adapter].feedforward_modules) for adapter in adapters]
if not feedforward_modules:
raise ValueError(f"Found no adapter matching the names in {adapters}")
if len(set(feedforward_modules)) > 1:
raise ValueError(
"all adapter configs should follow the same target modules type. "
"Combining adapters with `feedforward_modules` type being a mix of list/set and string is not supported."
)

if feedforward_modules[0] == str:
new_feedforward_modules = "|".join(
f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
)
elif feedforward_modules[0] == set:
new_feedforward_modules = reduce(
operator.or_, (self.peft_config[adapter].feedforward_modules for adapter in adapters)
)
if str in feedforward_module_types:
new_feedforward_modules = "|".join(f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters)
else:
raise TypeError(f"Invalid type {feedforward_modules[0]} found in feedforward_modules")
new_feedforward_modules = set.union(*(self.peft_config[adapter].feedforward_modules for adapter in adapters))

return new_target_modules, new_feedforward_modules

Expand Down

0 comments on commit 3dd8e12

Please sign in to comment.