Add docs, address comments
siddartha-RE committed Mar 4, 2024
1 parent fd9bc28 commit ba0db26
Showing 4 changed files with 29 additions and 18 deletions.
12 changes: 12 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -77,6 +77,18 @@ The default LoRA settings in PEFT add trainable weights to the query and value l
config = LoraConfig(target_modules="all-linear", ...)
```

### Memory efficient Layer Replication with LoRA

One approach used to improve the performance of models, borrowed from model merging techniques, is to expand a model by duplicating layers to build a larger model from a pretrained model of a given size.
For example, growing a 7B model into a 10B model as described in the [SOLAR](https://arxiv.org/abs/2312.15166) paper. PEFT LoRA supports this kind of expansion in a memory efficient manner that allows further fine-tuning
using LoRA adapters attached to the layers after replication. The replicated layers do not take additional memory because they share the underlying weights, so the only additional memory required is the
memory for the adapter weights. To use this feature, create a config with the `layer_replication` argument:
```py
config = LoraConfig(layer_replication=[[0,4], [2,5]], ...)
```
Assuming the original model had 5 layers `[0, 1, 2, 3, 4]`, this would create a model with 7 layers arranged as `[0, 1, 2, 3, 2, 3, 4]`. This follows the mergekit passthrough merge convention, where sequences
of layers specified as start-inclusive, end-exclusive tuples are stacked to build the final model. Note that each layer in the final model gets its own distinct set of LoRA adapters.
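
For a fuller picture, here is a minimal end-to-end sketch. The checkpoint name, target modules, layer ranges, and hyperparameters below are illustrative placeholders rather than part of the official example; any Llama, Bert, or Falcon compatible model with enough layers should work the same way.

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder checkpoint; substitute any Llama-style model you have access to.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

config = LoraConfig(
    # Stack layers [0, 16), [8, 24), and [16, 32) to grow 32 layers into 48.
    # The replicated layers share base weights, but each position in the new
    # stack gets its own LoRA adapter.
    layer_replication=[[0, 16], [8, 24], [16, 32]],
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```

Fine-tuning then proceeds as with any other LoRA setup; the base weights stay frozen and shared, so only the adapter weights add to the memory footprint.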

## Merge adapters

While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
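
As a brief hedged illustration of that workflow (the model checkpoint and adapter path are placeholders):

```py
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# "path/to/lora-adapter" stands in for a trained adapter directory or Hub repo.
model = PeftModel.from_pretrained(base_model, "path/to/lora-adapter")

# Fold the LoRA weights into the base weights and drop the adapter modules,
# leaving a standalone model with no adapter-related inference overhead.
merged_model = model.merge_and_unload()
```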
11 changes: 6 additions & 5 deletions src/peft/tuners/lora/config.py
Expand Up @@ -102,9 +102,9 @@ class LoraConfig(PeftConfig):
and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
quantized model in this case, as LoftQ will quantize the model itself.
layer_replication (`List[Tuple[int, int]]`):
-Build a new stack of layers by stacking the original model layers according to the ranges specified.
-This allows expanding (or shrinking) the model without duplicating the base model weights.
-The new layers will all have separate LoRA adapters attached to them.
+Build a new stack of layers by stacking the original model layers according to the ranges specified. This
+allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
+all have separate LoRA adapters attached to them.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -234,8 +234,9 @@ class LoraConfig(PeftConfig):
metadata={
"help": (
"This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. "
-"The transformation expects a `layers` module list in the model which it modifies to expand the number of modules. "
-"Base weights are shared so the memory usage is close to the original model. The inteneded use is these base weights "
+"The transformation handles models (currently Llama, Bert or Falcon compatible architectures) with "
+"a module list in the model which it modifies to expand the number of modules. "
+"Base weights are shared so the memory usage is close to the original model. The intended use is these base weights "
"remain fixed during fine-tuning but each layer has a separate LoRA adapter so the layers can be specialized via "
"the adapter layers fit during fine-tuning."
"The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n"
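To make the `[start, end)` semantics of this help string concrete, here is a small standalone sketch (plain Python, independent of PEFT; the function name is made up for illustration):

```py
def expand_layer_map(layer_map):
    # Each (start, end) pair is start-inclusive and end-exclusive; the ranges
    # are concatenated in order, so overlapping ranges yield repeated layers.
    return [i for start, end in layer_map for i in range(start, end)]

print(expand_layer_map([[0, 4], [2, 5]]))  # [0, 1, 2, 3, 2, 3, 4]
```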
4 changes: 0 additions & 4 deletions src/peft/tuners/lora/model.py
@@ -359,10 +359,6 @@ def _check_merge_allowed(self):
if self.peft_config.get("layer_replication"):
raise ValueError("Cannot merge LORA layers when base model layers are replicated")

-def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
-    self._check_merge_allowed()
-    super().merge_adapter(adapter_names=adapter_names)

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
20 changes: 11 additions & 9 deletions src/peft/tuners/tuners_utils.py
@@ -360,6 +360,7 @@ def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
+self._check_merge_allowed()
for module in self.model.modules():
if isinstance(module, BaseTunerLayer):
with onload_layer(module):
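
With the check now running in the base class's `merge_adapter`, merging a layer-replicated LoRA model is expected to fail loudly rather than silently collapsing adapters onto shared weights. A hedged sketch of that behavior (the tiny test checkpoint is illustrative; any Llama-style model with at least two layers should behave the same):

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
config = LoraConfig(layer_replication=[[0, 2], [1, 2]], target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base, config)

try:
    # Replicated layers share one base weight tensor but carry different LoRA
    # adapters, so there is no single merged weight that is correct for every copy.
    peft_model.merge_adapter()
except ValueError as err:
    print(err)  # Cannot merge LORA layers when base model layers are replicated
```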
@@ -693,9 +694,8 @@ def check_adapters_to_merge(module: BaseTunerLayer, adapter_names: Optional[list
def clone_module(module: nn.Module, share_weights=False):
"""Clone a module in a pytorch model.
-Clones a module of a model, optionally sharing all the parameters between
-the original and the clone. Simplifies reusing a module when manipulating the
-architecture of a model.
+Clones a module of a model, optionally sharing all the parameters between the original and the clone. Simplifies
+reusing a module when manipulating the architecture of a model.
"""
clone = copy.deepcopy(module)
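
The `_share_weights` helper itself is collapsed in this view, so the following is only a hypothetical sketch of what `share_weights=True` is meant to achieve: re-point the clone's parameters at the original tensors so the copy adds no parameter memory.

```py
import copy
import torch.nn as nn

def share_weights_sketch(src: nn.Module, dst: nn.Module) -> None:
    # Hypothetical stand-in for the collapsed helper: alias each parameter of
    # the clone to the original module's tensor, recursing through submodules.
    for name, param in src.named_parameters(recurse=False):
        setattr(dst, name, param)
    for child_name, child in src.named_children():
        share_weights_sketch(child, getattr(dst, child_name))

layer = nn.Linear(16, 16)
clone = copy.deepcopy(layer)
share_weights_sketch(layer, clone)
assert clone.weight is layer.weight  # the clone reuses the original storage
```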

@@ -713,10 +713,9 @@ def _share_weights(src: nn.Module, dst: nn.Module):
def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
"""Replicate layers in a transformer model with weight sharing.
-This function looks for a module list attribute at model[(.model)*].layers
-and replicates the layers in the module list according to the layer map.
-For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3, 4]`
-and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`.
+This function looks for a module list attribute at model[(.model)*].layers and replicates the layers in the module
+list according to the layer map. For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3,
+4]` and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`.
"""
while hasattr(model, "model"):
model = model.model
@@ -736,7 +735,10 @@ def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
model_type = "falcon"
layers = model.h
if not model_type or not isinstance(layers, nn.ModuleList):
-    raise ValueError("Could not locate the layers attribute in the model.")
+    raise ValueError(
+        "Could not locate the layers attribute in the model. "
+        "Expected Llama, Bert or Falcon compatible architectures."
+    )

new_layers = []
for start, end in layer_map:
@@ -755,6 +757,6 @@ def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
elif model_type == "falcon":
model.h = layers
else:
-    raise AssertionError("Unexpected model type.")
+    raise ValueError("Unexpected model type, need to handle post-processing of layers.")
if hasattr(model.config, "num_hidden_layers"): # Common to Llama, Bert, Falcon.
model.config.num_hidden_layers = len(new_layers)
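
To see what the replication loop above produces, here is a toy version on a plain `nn.ModuleList`; the `tied_copy` helper is a stand-in for `clone_module(..., share_weights=True)` and is not PEFT code.

```py
import torch.nn as nn

def tied_copy(layer: nn.Linear) -> nn.Linear:
    # Stand-in for clone_module(layer, share_weights=True): a distinct module
    # object whose parameters alias the original's.
    twin = nn.Linear(layer.in_features, layer.out_features)
    twin.weight = layer.weight
    twin.bias = layer.bias
    return twin

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(5)])
layer_map = [(0, 4), (2, 5)]  # start-inclusive, end-exclusive ranges

new_layers = [tied_copy(layers[i]) for start, end in layer_map for i in range(start, end)]
layers = nn.ModuleList(new_layers)

print(len(layers))                           # 7 layers, arranged as [0, 1, 2, 3, 2, 3, 4]
print(layers[2] is layers[4])                # False: distinct modules, so each can hold its own adapter
print(layers[2].weight is layers[4].weight)  # True: the underlying weight is shared, adding no memory
```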
