Address review comments
siddartha-RE committed Feb 19, 2024
1 parent f0cbcd7 commit 547cade
Showing 4 changed files with 39 additions and 10 deletions.
12 changes: 8 additions & 4 deletions src/peft/tuners/lora/config.py
@@ -15,7 +15,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, Union
+from typing import Literal, Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType
@@ -229,13 +229,17 @@ class LoraConfig(PeftConfig):
},
)
+# Enables replicating layers in a model to expand it to a larger model.
-layer_replication: Optional[List[Tuple[int, int]]] = field(
+layer_replication: Optional[list[tuple[int, int]]] = field(
default=None,
metadata={
"help": (
"This enables using LoRA to effectively expand a model to a larger size by repeating some layers. "
"This enables using LoRA to effectively expand a transformer model to a larger size by repeating some layers. "
"The transformation expects a `layers` module list in the model which it modifies to expand the number of modules. "
"Base weights are shared so the memory usage is close to the original model."
"The format is a list of (start, end) pairs which specify the layer ranges to stack."
"The format is a list of [start, end) pairs which specify the layer ranges to stack. For example:\n"
" original: `[0, 1, 2, 3, 4]`\n"
" layer_replication: `[[0, 4], [2, 5]]`\n"
" final: `[0, 1, 2, 3, 2, 3, 4]`"
)
}
)
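
For reference, a minimal sketch (not part of this diff) of how the new `layer_replication` option might be passed to `LoraConfig`; the rank and `target_modules` values below are placeholder assumptions:

```python
from peft import LoraConfig

# Hypothetical example: expand a 5-layer model by stacking the ranges [0, 4) and [2, 5),
# i.e. layers [0, 1, 2, 3] + [2, 3, 4] -> [0, 1, 2, 3, 2, 3, 4].
config = LoraConfig(
    r=8,                                  # placeholder rank
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    layer_replication=[[0, 4], [2, 5]],
)
```
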
4 changes: 4 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -350,6 +350,10 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
self.active_adapter = adapter_name

def _check_merge_allowed(self):
"""Verify that the configuration supports merging.
Currently gptq quantization and replicated layers do not support merging.
"""
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
if self.peft_config.get('layer_replication'):
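
A rough sketch of how this guard is expected to surface for users, assuming a tiny Llama-style model and the `merge_and_unload` entry point (neither appears in this diff); merging is expected to be rejected when `layer_replication` is set, since replicated layers share base weights:

```python
from peft import LoraConfig, get_peft_model
from transformers import LlamaConfig, LlamaForCausalLM

# Toy 2-layer decoder, used only for illustration.
base_model = LlamaForCausalLM(
    LlamaConfig(num_hidden_layers=2, hidden_size=32, intermediate_size=64, num_attention_heads=4, vocab_size=128)
)
config = LoraConfig(target_modules=["q_proj", "v_proj"], layer_replication=[[0, 2], [0, 2]])
peft_model = get_peft_model(base_model, config)

try:
    peft_model.merge_and_unload()
except ValueError as err:  # assumed to raise a ValueError, like the gptq branch shown above
    print(f"Merge rejected: {err}")
```
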
26 changes: 22 additions & 4 deletions src/peft/tuners/tuners_utils.py
@@ -19,7 +19,7 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

import torch
from accelerate.hooks import AlignDevicesHook
@@ -257,6 +257,13 @@ def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
pass

+def _check_merge_allowed(self):
+"""Helper method to check whether the adapter can be merged.
+Raise a ValueError if it is not possible to merge the adapter with the given configuration.
+"""
+pass

def inject_adapter(self, model: nn.Module, adapter_name: str):
r"""
Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
@@ -697,16 +704,27 @@ def _share_weights(src: nn.Module, dst: nn.Module):
return clone


-def replicate_layers(model: nn.Module, layer_map: List[Tuple[int, int]]):
+def replicate_layers(model: nn.Module, layer_map: list[tuple[int, int]]):
+"""Replicate layers in a transformer model with weight sharing.
+This function looks for a module list attribute at model[(.model)*].layers
+and replicates the layers in the module list according to the layer map.
+For example the map `[[0, 4], [2, 5]]` will take the set of layers `[0, 1, 2, 3, 4]`
+and replace them with a module list containing `[0, 1, 2, 3, 2, 3, 4]`.
+"""
+while hasattr(model, 'model'):
+model = model.model
+if not hasattr(model, 'layers'):
+raise ValueError('Could not locate the layers attribute in the model.')
new_layers = []
for start, end in layer_map:
for i in range(start, end):
current_idx = len(new_layers)
-new_layers.append(clone_module(model.base_model.layers[i], share_weights=True))
+new_layers.append(clone_module(model.layers[i], share_weights=True))
# This is a hack needed to work around the layer_idx introduced in HF transformers.
for submodule in new_layers[-1].modules():
if hasattr(submodule, 'layer_idx'):
submodule.layer_idx = current_idx
-model.base_model.layers = nn.ModuleList(new_layers)
+model.layers = nn.ModuleList(new_layers)
if hasattr(model.config, 'num_hidden_layers'):
model.config.num_hidden_layers = len(new_layers)
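
To illustrate the traversal contract (`model[(.model)*].layers`) and the [start, end) stacking described in the docstring, here is a small self-contained sketch; `TinyBackbone` and `TinyModel` are made-up stand-ins, and the final assertion assumes `clone_module(..., share_weights=True)` reuses the source parameters as described above:

```python
from types import SimpleNamespace

import torch.nn as nn

from peft.tuners.tuners_utils import replicate_layers


class TinyBackbone(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_layers))
        self.config = SimpleNamespace(num_hidden_layers=num_layers)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyBackbone(5)  # found via the model(.model)* traversal


model = TinyModel()
replicate_layers(model, [(0, 4), (2, 5)])  # [0, 1, 2, 3] + [2, 3, 4]

assert len(model.model.layers) == 7
assert model.model.config.num_hidden_layers == 7
# New indices 2 and 4 both originate from original layer 2, so their weights should share storage.
assert model.model.layers[2].weight.data_ptr() == model.model.layers[4].weight.data_ptr()
```
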
7 changes: 5 additions & 2 deletions tests/test_decoder_models.py
@@ -316,8 +316,11 @@ def test_lora_layer_replication(self):
base_model_name_or_path=model_id,
**config_kwargs,
)
+assert 2 == len(model.model.layers), 'Expected 2 layers in original model.'
model = get_peft_model(model, config)
-self.assertEquals(4, len(model.base_model.model.model.layers), 'Expected 4 layers in adapted model.')
-self.assertEquals(8, len([n for n, _ in model.named_parameters() if '.lora_A.' in n]))
+assert 4 == len(model.base_model.model.model.layers), 'Expected 4 layers in adapted model.'
+assert 8 == len([n for n, _ in model.named_parameters() if '.lora_A.' in n]), (
+'Expected 8 LoRA adapters since we are adding one each for up and down.'
+)
self._test_prepare_for_training(model_id, LoraConfig, config_kwargs)
self._test_generate(model_id, LoraConfig, config_kwargs)
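
A self-contained variant of this test (the tiny model and the exact `layer_replication`/`target_modules` kwargs here are assumptions, since `config_kwargs` is not shown in this hunk) that reproduces the same layer and adapter counts:

```python
from peft import LoraConfig, get_peft_model
from transformers import LlamaConfig, LlamaForCausalLM

base = LlamaForCausalLM(
    LlamaConfig(num_hidden_layers=2, hidden_size=32, intermediate_size=64, num_attention_heads=4, vocab_size=128)
)
config = LoraConfig(
    target_modules=["up_proj", "down_proj"],
    layer_replication=[[0, 2], [0, 2]],  # 2 layers -> [0, 1, 0, 1]
)
model = get_peft_model(base, config)

assert len(model.base_model.model.model.layers) == 4, "Expected 4 layers in adapted model."
assert len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8  # 2 target modules x 4 layers
```
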
