
Commit

Improve test and fix typos
siddartha-RE committed Mar 5, 2024
1 parent 3b1e693 commit 1da5e1d
Showing 4 changed files with 25 additions and 10 deletions.
10 changes: 4 additions & 6 deletions docs/source/developer_guides/lora.md
@@ -89,15 +89,13 @@ config = LoraConfig(target_modules="all-linear", ...)

### Memory efficient Layer Replication with LoRA

One of approach used to improve the performance of models is using model merging techniques is to expand a model by duplicating layers in the model to build a larger model from a pretrained model of a given size.
For example increasing a 7B model to a 10B model as described in the [SOLAR](https://arxiv.org/abs/2312.15166). PEFT LoRA supports this kind of merge in a memory efficient manner that supports further fine-tuning
using LoRA adapters attached to the layers post replication of the layers. The replicated layers do not take additional memory as they share the underlying weights so the only additional memory required is the
memory for the adapter weights. To use this feature you would create a config with the `layer_replication` argument.
An approach used to improve the performance of models is to expand a model by duplicating layers to build a larger model from a pretrained model of a given size, for example increasing a 7B model to a 10B model as described in the [SOLAR](https://arxiv.org/abs/2312.15166) paper. PEFT LoRA supports this kind of expansion in a memory-efficient manner that allows further fine-tuning using LoRA adapters attached to the layers after replication. The replicated layers do not take additional memory because they share the underlying weights, so the only additional memory required is the memory for the adapter weights. To use this feature, create a config with the `layer_replication` argument.

```py
config = LoraConfig(layer_replication=[[0,4], [2,5]], ...)
```
Given the original model had 5 layers `[0, 1, 2 ,3, 4]`, this would create a model with 7 layers arranged as `[0, 1, 2, 3, 2, 3, 4]`. This follows the mergekit pass through merge convention where sequences
of layers specified as start inclusive and end exclusive tuples are stacked to build the final model. It is important to note that each layer in the final model gets its own distinct set of LoRA adpaters.

Assuming the original model had 5 layers `[0, 1, 2, 3, 4]`, this would create a model with 7 layers arranged as `[0, 1, 2, 3, 2, 3, 4]`. This follows the [mergekit](https://github.com/arcee-ai/mergekit) pass-through merge convention, where sequences of layers specified as start-inclusive, end-exclusive tuples are stacked to build the final model. Each layer in the final model gets its own distinct set of LoRA adapters.
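For illustration, the start-inclusive, end-exclusive ranges can be thought of as being concatenated into a single list of layer indices. The sketch below is not part of PEFT or of this commit; the helper name `expand_layer_indices` is made up purely to reproduce the mapping described above:

```py
# Hypothetical helper, not PEFT API: shows how start-inclusive, end-exclusive
# ranges such as [[0, 4], [2, 5]] are stacked into the final layer order.
def expand_layer_indices(layer_replication):
    expanded = []
    for start, end in layer_replication:
        expanded.extend(range(start, end))
    return expanded

# A 5-layer model [0, 1, 2, 3, 4] expands to 7 layers.
print(expand_layer_indices([[0, 4], [2, 5]]))  # [0, 1, 2, 3, 2, 3, 4]
```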

[Fewshot-Metamath-OrcaVicuna-Mistral-10B](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B) is an example of a model trained using this method on Mistral-7B expanded to 10B. The [adapter_config.json](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B/blob/main/adapter_config.json) shows a sample LoRA adapter config applying this method for fine-tuning.
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/config.py
@@ -108,7 +108,7 @@ class LoraConfig(PeftConfig):
ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than
pure LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
layer_replication(`List[Tuple[int, int]]):
layer_replication(`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
all have separate LoRA adapters attached to them.
@@ -254,7 +254,7 @@ class LoraConfig(PeftConfig):
metadata={
"help": (
"This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. "
"The transformation handles models (currently Llama, Bert or Falcon compatible architecutres) with "
"The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with "
"a module list in the model which it modifies to expand the number of modules. "
"Base weights are shared so the memory usage is close to the original model. The intended use is these base weights "
"remain fixed during finetuning but each layer has a separate LoRA adapter so the layers can be specialized via "
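As a usage sketch tying this option together with a standard LoRA setup (not part of this commit; the checkpoint path is a placeholder and a Llama-compatible architecture is assumed):

```py
# Sketch only: expand a Llama-compatible base model by replicating layers,
# then attach LoRA adapters to the expanded stack.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("path/to/llama-compatible-model")  # placeholder path

config = LoraConfig(
    layer_replication=[[0, 4], [2, 5]],   # stack layers 0-3 and 2-4 (end exclusive)
    target_modules=["q_proj", "v_proj"],  # modules that receive LoRA adapters
    r=8,
    lora_alpha=16,
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, config)  # base weights are shared; only adapter weights add memory
model.print_trainable_parameters()
```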
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/model.py
@@ -142,7 +142,7 @@ def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
Args:
peft_config (`PeftConfig`):
The prepared adapter config.
model_config (`nn.Module`):
model (`nn.Module`):
The model that is going to be adapted.
"""
if peft_config.layer_replication:
19 changes: 18 additions & 1 deletion tests/test_decoder_models.py
@@ -318,7 +318,24 @@ def test_lora_layer_replication(self):
)
assert len(model.model.layers) == 2, "Expected 2 layers in original model."
model = get_peft_model(model, config)
assert len(model.base_model.model.model.layers) == 4, "Expected 4 layers in adapted model."
layers = model.base_model.model.model.layers
assert len(layers) == 4, "Expected 4 layers in adapted model."
assert (
layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() ==
layers[1].mlp.up_proj.base_layer.weight.data.storage().data_ptr() and
layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() ==
layers[3].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
), "Expected layers 0-1 and 2-3 to share weights"
assert (
layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() !=
layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr()
), "Expected layers 0 and 2 to have different weights"
assert (
layers[0].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() !=
layers[1].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() and
layers[2].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() !=
layers[3].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr()
), "Expected all LoRA adapters to have distinct weights"
assert (
len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8
), "Expected 8 LoRA adapters since we are adding one each for up and down."
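Outside the test suite, the same sharing property can be spot-checked in user code roughly as follows (a sketch assuming the Llama-style attribute layout used in the test above; the attribute paths differ for other architectures):

```py
# Sketch: group the expanded layers by the storage address of one base weight to see
# which copies share memory, and confirm every layer has its own LoRA A matrix.
from collections import defaultdict

layers = model.base_model.model.model.layers  # Llama-style layout (assumption)

shared = defaultdict(list)
for i, layer in enumerate(layers):
    shared[layer.mlp.up_proj.base_layer.weight.data_ptr()].append(i)
print(dict(shared))  # replicated copies of the same original layer are grouped together

lora_ptrs = {layer.mlp.up_proj.lora_A.default.weight.data_ptr() for layer in layers}
assert len(lora_ptrs) == len(layers), "every layer should have a distinct LoRA adapter"
```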
