Skip to content

Commit

Permalink
address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrs committed May 7, 2024
1 parent faa30e7 commit e4bc355
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
8 changes: 7 additions & 1 deletion docs/source/developer_guides/model_merging.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ print(tokenizer.decode(outputs[0]))


## Merging (IA)³ Models
(IA)³ models support linear model merging. To merge (IA)³ models, you can use the [`~IA3Model.add_weighted_adapter`] method. This method is similar to the [`~LoraModel.add_weighted_adapter`] method, but it doesn't accept the `combination_type` parameter. Assuming we have a PEFT model and three (IA)³ adapters, we can merge them as follows:
(IA)³ models support linear merging of adapters. To merge adapters in an (IA)³ model, use the [`~IA3Model.add_weighted_adapter`] method. This method is analogous to [`~LoraModel.add_weighted_adapter`], with the key difference being the absence of the `combination_type` parameter. For example, to merge three (IA)³ adapters into a PEFT model, you would proceed as follows:

```py
adapters = ["adapter1", "adapter2", "adapter3"]
weights = [0.4, 0.3, 0.3]
adapter_name = "merge"
model.add_weighted_adapter(adapters, weights, adapter_name)
```

It is recommended that the weights sum to 1.0 to preserve the scale of the model. The merged adapter can then be set as the active adapter using the [`~IA3Model.set_adapter`] method:

```py
model.set_adapter("merge")
```
27 changes: 10 additions & 17 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,11 +1198,7 @@ def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_lis

# test ties re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[:2],
weight_list[:2],
"multi_adapter_ties_reweighting",
combination_type="ties",
density=0.5,
adapter_list[:2], weight_list[:2], "multi_adapter_ties_reweighting", combination_type="ties", density=0.5
)

# test dare_linear re-weighting with multiple adapters
Expand Down Expand Up @@ -1367,20 +1363,17 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
**config_kwargs,
)

# Define a dictionary to map config types to their respective test functions
test_functions = {
LoraConfig: self._test_weighted_combination_of_adapters_lora,
IA3Config: self._test_weighted_combination_of_adapters_ia3,
}
if not isinstance(config, (LoraConfig, IA3Config)):
# This test is only applicable for Lora and IA3 configs
return pytest.skip(f"Test not applicable for {config}")

# Get the test function based on the config type
test_function = test_functions.get(type(config))
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])

if test_function:
# Only instantiate the model if a valid config is provided
model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])
test_function(model, config, adapter_list, weight_list)
if isinstance(config, LoraConfig):
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list)
elif isinstance(config, IA3Config):
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list)
else:
pytest.skip(f"Test not applicable for {config}")

Expand Down

0 comments on commit e4bc355

Please sign in to comment.