FEAT: Add LoRA+ (#1915)
Add LoRA+: Efficient Low Rank Adaptation of Large Models

https://arxiv.org/abs/2402.12354

Call create_loraplus_optimizer to initialize an optimizer with parameter groups and
learning rates that are especially effective for LoRA training.

Builds upon this code base:

https://github.com/nikhil-ghosh-berkeley/loraplus

---------

Co-authored-by: moghadas76 <s.m.moghadas2012@gmail.com>
Co-authored-by: Chris Hua <stillmatic@users.noreply.github.com>
3 people authored Jul 29, 2024
1 parent 296fbcd commit 273acf0
Showing 5 changed files with 271 additions and 0 deletions.
33 changes: 33 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -165,6 +165,39 @@ Assuming the original model had 5 layers `[0, 1, 2, 3, 4]`, this would create a
[Fewshot-Metamath-OrcaVicuna-Mistral-10B](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B) is an example of a model trained using this method on Mistral-7B expanded to 10B. The
[adapter_config.json](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B/blob/main/adapter_config.json) shows a sample LoRA adapter config applying this method for fine-tuning.

## Optimizers

LoRA training can optionally include special-purpose optimizers. Currently, the only such optimizer is LoRA+.

### LoRA+ optimized LoRA

LoRA training can be optimized using [LoRA+](https://arxiv.org/abs/2402.12354), which assigns different learning rates to the adapter matrices A and B. This has been shown to increase finetuning speed by up to 2x and improve performance by 1-2%.

```py
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from transformers import Trainer
import bitsandbytes as bnb

base_model = ...
config = LoraConfig(...)
model = get_peft_model(base_model, config)

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.Adam8bit,
    lr=5e-5,
    loraplus_lr_ratio=16,
)
scheduler = None

...
trainer = Trainer(
    ...,
    optimizers=(optimizer, scheduler),
)
```
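
Passing `scheduler = None` lets the `Trainer` create its default learning rate scheduler on top of the LoRA+ optimizer; alternatively, construct and pass your own scheduler instance.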

## Merge LoRA weights into the base model

While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
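
For illustration, here is a minimal sketch of this workflow; the model identifier and adapter path below are placeholders:

```py
from transformers import AutoModelForCausalLM
from peft import PeftModel

# placeholder identifiers -- substitute your own base model and LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained("base-model-id")
model = PeftModel.from_pretrained(base_model, "path/to/lora-adapter")

# fold the adapter weights into the base weights and drop the adapter modules
merged_model = model.merge_and_unload()
merged_model.save_pretrained("path/to/merged-model")
```
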
1 change: 1 addition & 0 deletions src/peft/mapping.py
@@ -88,6 +88,7 @@
"P_TUNING": PromptEncoderConfig,
"LORA": LoraConfig,
"LOHA": LoHaConfig,
"LORAPLUS": LoraConfig,
"LOKR": LoKrConfig,
"ADALORA": AdaLoraConfig,
"BOFT": BOFTConfig,
18 changes: 18 additions & 0 deletions src/peft/optimizers/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loraplus import create_loraplus_optimizer


__all__ = ["create_loraplus_optimizer"]
120 changes: 120 additions & 0 deletions src/peft/optimizers/loraplus.py
@@ -0,0 +1,120 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the implementation of the LoraPlus optimizer.
"""

from __future__ import annotations

from operator import attrgetter

import torch.nn as nn
from torch.optim import Optimizer
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

from ..peft_model import PeftModel
from ..tuners.lora.layer import Embedding


def create_loraplus_optimizer(
    model: PeftModel, optimizer_cls: type[Optimizer], *, lr: float, loraplus_lr_ratio: float, **kwargs
) -> Optimizer:
"""
Creates a LoraPlus optimizer.
Efficient Low Rank Adaptation of Large Models: https://arxiv.org/abs/2402.12354
Reference: https://github.com/nikhil-ghosh-berkeley/loraplus/
Args:
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
lr (`float`): The learning rate to be used for the optimizer.
loraplus_lr_ratio (`float`):
The ratio of learning ηB/ηA where ηA (lr) is passed in as the optimizer learning rate. Should be ≥1. Should
be set in tandem with the optimizer learning rate (lr); should be larger when the task is more difficult
and the model needs to update its features to learn well. In this case, it helps to make the learning rate
slightly smaller (e.g., by a factor of 2) than typical vanilla LoRA learning rates
loraplus_lr_embedding (optional `float`):
If LoRA modules are added to embedding layers your can specify a different learning rate for them. Default
value 1e-6.
kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
Returns:
`torch.optim.Optimizer`: An instance of the specified optimizer class configured with the model's parameters
organized into groups with custom learning rates.
"""

    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        module = attrgetter(name)(model)
        if isinstance(module, Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    loraplus_weight_decay = kwargs.pop("loraplus_weight_decay", 0.0)
    loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **kwargs)
    eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
    if optimizer_cls.__name__ in eight_bit_names:
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                manager.register_module_override(module, "weight", {"optim_bits": 32})
    return optimizer
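
As a rough illustration of the grouping logic above, here is a minimal sketch that builds the optimizer with a plain `torch.optim.AdamW` and a made-up `ToyLora` module whose parameter names mimic LoRA adapter matrices (both are illustrative assumptions, not part of this code):

```py
import torch
from torch import nn

from peft.optimizers import create_loraplus_optimizer


class ToyLora(nn.Module):
    # hypothetical module; only its parameter names matter for the grouping
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Linear(16, 4, bias=False)
        self.lora_B = nn.Linear(4, 16, bias=False)

    def forward(self, x):
        return self.lora_B(self.lora_A(x))


optimizer = create_loraplus_optimizer(
    model=ToyLora(),
    optimizer_cls=torch.optim.AdamW,
    lr=2e-5,
    loraplus_lr_ratio=16,
)
# four groups: A-side params, embedding params, B-side params with and without weight decay;
# the two B-side groups get lr * loraplus_lr_ratio
for group in optimizer.param_groups:
    print(group["lr"], len(group["params"]))
```

With the bitsandbytes 8-bit optimizers listed above, the function additionally registers embedding weights for 32-bit optimization via `GlobalOptimManager`.
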
99 changes: 99 additions & 0 deletions tests/test_loraplus.py
@@ -0,0 +1,99 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
from torch import nn

from peft.import_utils import is_bnb_available
from peft.optimizers import create_loraplus_optimizer

from .testing_utils import require_bitsandbytes


if is_bnb_available():
    import bitsandbytes as bnb


class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X


@require_bitsandbytes
def test_lora_plus_helper_success():
    model = SimpleNet()
    optimizer_cls = bnb.optim.Adam8bit
    lr = 5e-5
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    loraplus_lr_ratio = 1.2
    loraplus_lr_embedding = 1e-6
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=lr,
        loraplus_lr_ratio=loraplus_lr_ratio,
        loraplus_lr_embedding=loraplus_lr_embedding,
        **optim_config,
    )
    assert optim is not None
    assert len(optim.param_groups) == 4
    assert optim.param_groups[0]["lr"] == lr
    assert optim.param_groups[1]["lr"] == loraplus_lr_embedding
    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (lr * loraplus_lr_ratio)


@require_bitsandbytes
def test_lora_plus_optimizer_success():
    """
    Test that the optimizer is correctly created and that the step function runs without any exception.
    """
    optimizer_cls = bnb.optim.Adam8bit
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    model: SimpleNet = SimpleNet().cuda()
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=5e-5,
        loraplus_lr_ratio=1.2,
        loraplus_lr_embedding=1e-6,
        **optim_config,
    )
    loss = torch.nn.CrossEntropyLoss()
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
    x = torch.randint(100, (2, 4, 10)).cuda()
    output = model(x).permute(0, 3, 1, 2)
    label = torch.randint(16, (2, 4, 10)).cuda()
    loss_value = loss(output, label)
    loss_value.backward()
    optim.step()
