Add lora+ implementation #1915

Merged · 34 commits · Jul 29, 2024

Commits
5f54802
Add lora+ implentation
moghadas76 Feb 26, 2024
1703b91
Support LoraPlus cfg
moghadas76 Mar 17, 2024
105bb9b
Fix QA comments
moghadas76 Mar 31, 2024
77e04ff
Fix test
moghadas76 Mar 31, 2024
e6e3979
Fix tests
moghadas76 Apr 19, 2024
7b64437
Fix comments
moghadas76 May 16, 2024
1ceab2b
Add unit test
moghadas76 May 17, 2024
27b375b
Decouple file structures
moghadas76 May 18, 2024
4a385d4
Fix clean code issues
moghadas76 May 21, 2024
6228919
Fix styling problem
moghadas76 May 21, 2024
083b0b9
Fix formatter
moghadas76 May 22, 2024
1a75154
Fix QA comments
moghadas76 Jun 9, 2024
132dfac
Fix docs
moghadas76 Jun 9, 2024
1d35a72
style fixes
kallewoof Jul 2, 2024
65069de
move loraplus_lr_ratio out of opt_kwargs, and other fixes
kallewoof Jul 2, 2024
9146816
add support for other 8 bit optimizers
kallewoof Jul 2, 2024
4bca4a7
revert unneeded import
kallewoof Jul 4, 2024
a5df785
clean out old code
kallewoof Jul 9, 2024
af584cb
add 'LORAPLUS' to peft config mapping
kallewoof Jul 15, 2024
56a5227
conditional bnb in lora+ tests
kallewoof Jul 15, 2024
b6d2d6a
added compat for py 3.8
kallewoof Jul 16, 2024
95b9e80
license header
kallewoof Jul 16, 2024
0460dcd
review fixes
kallewoof Jul 16, 2024
80f7392
review fixes
kallewoof Jul 17, 2024
52d8e0b
make lr and loraplus_lr_ratio required forced kw args
kallewoof Jul 18, 2024
f550e51
lora+: do not propagate weight_decay to optimizer
kallewoof Jul 19, 2024
b2f802b
doc: add LoRA+ optimizer example
kallewoof Jul 22, 2024
29b3f5d
remove LoraPlusConfig
kallewoof Jul 22, 2024
d445dad
rename weight_decay to loraplus_weight_decay to avoid potential ambig…
kallewoof Jul 22, 2024
bf54d0a
review fixes LoRA+ documentation
kallewoof Jul 23, 2024
5981ecc
specify optimizer params
kallewoof Jul 23, 2024
afe07e1
assumption about lora set up
kallewoof Jul 23, 2024
c60ca71
add lora hint to lora+ example code
kallewoof Jul 23, 2024
babd89b
missing LoraConfig import
kallewoof Jul 23, 2024
33 changes: 33 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -165,6 +165,39 @@ Assuming the original model had 5 layers `[0, 1, 2 ,3, 4]`, this would create a
[Fewshot-Metamath-OrcaVicuna-Mistral-10B](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B) is an example of a model trained using this method on Mistral-7B expanded to 10B. The
[adapter_config.json](https://huggingface.co/abacusai/Fewshot-Metamath-OrcaVicuna-Mistral-10B/blob/main/adapter_config.json) shows a sample LoRA adapter config applying this method for fine-tuning.

## Optimizers

LoRA training can optionally include special purpose optimizers. Currently the only such optimizer is LoRA+.

### LoRA+ optimized LoRA

LoRA training can be optimized using [LoRA+](https://arxiv.org/abs/2402.12354), which uses different learning rates for the adapter matrices A and B; this has been shown to increase finetuning speed by up to 2x and improve performance by 1-2%.

```py
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from transformers import Trainer
import bitsandbytes as bnb

base_model = ...
config = LoraConfig(...)
model = get_peft_model(base_model, config)

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.Adam8bit,
    lr=5e-5,
    loraplus_lr_ratio=16,
)
scheduler = None

...
trainer = Trainer(
    ...,
    optimizers=(optimizer, scheduler),
)
```
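
`create_loraplus_optimizer` also accepts two optional keyword arguments that this PR consumes before the remaining kwargs are forwarded to the optimizer class: `loraplus_weight_decay` and `loraplus_lr_embedding` (see `src/peft/optimizers/loraplus.py` below). A minimal sketch of passing them, reusing the setup from the example above (values are illustrative):

```py
# Minimal sketch: the two loraplus_* kwargs are consumed by create_loraplus_optimizer,
# everything else (here eps) is passed through unchanged to the optimizer class.
optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=bnb.optim.Adam8bit,
    lr=5e-5,
    loraplus_lr_ratio=16,
    loraplus_weight_decay=0.0,   # weight decay for groups A, B, and embedding
    loraplus_lr_embedding=1e-6,  # separate learning rate for LoRA embedding layers
    eps=1e-6,                    # forwarded to bnb.optim.Adam8bit
)
```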

## Merge LoRA weights into the base model

While LoRA is significantly smaller and faster to train, you may encounter latency issues during inference due to separately loading the base model and the LoRA adapter. To eliminate latency, use the [`~LoraModel.merge_and_unload`] function to merge the adapter weights with the base model. This allows you to use the newly merged model as a standalone model. The [`~LoraModel.merge_and_unload`] function doesn't keep the adapter weights in memory.
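
As a minimal sketch of the call described above, assuming `model` is a `PeftModel` with a LoRA adapter loaded (the output path is a placeholder):

```py
merged_model = model.merge_and_unload()         # folds the LoRA weights into the base model
merged_model.save_pretrained("path/to/merged")  # placeholder path; the merged model is standalone
```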
1 change: 1 addition & 0 deletions src/peft/mapping.py
@@ -88,6 +88,7 @@
"P_TUNING": PromptEncoderConfig,
"LORA": LoraConfig,
"LOHA": LoHaConfig,
"LORAPLUS": LoraConfig,
"LOKR": LoKrConfig,
"ADALORA": AdaLoraConfig,
"BOFT": BOFTConfig,
18 changes: 18 additions & 0 deletions src/peft/optimizers/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loraplus import create_loraplus_optimizer


__all__ = ["create_loraplus_optimizer"]
120 changes: 120 additions & 0 deletions src/peft/optimizers/loraplus.py
@@ -0,0 +1,120 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the implementation of the LoraPlus optimizer.
"""

from __future__ import annotations

from operator import attrgetter

import torch.nn as nn
from torch.optim import Optimizer
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

from ..peft_model import PeftModel
from ..tuners.lora.layer import Embedding


def create_loraplus_optimizer(
    model: PeftModel, optimizer_cls: type[Optimizer], *, lr: float, loraplus_lr_ratio: float, **kwargs
) -> Optimizer:
"""
Creates a LoraPlus optimizer.

Efficient Low Rank Adaptation of Large Models: https://arxiv.org/abs/2402.12354

Reference: https://github.com/nikhil-ghosh-berkeley/loraplus/

Args:
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
lr (`float`): The learning rate to be used for the optimizer.
loraplus_lr_ratio (`float`):
The ratio of learning ηB/ηA where ηA (lr) is passed in as the optimizer learning rate. Should be ≥1. Should
be set in tandem with the optimizer learning rate (lr); should be larger when the task is more difficult
and the model needs to update its features to learn well. In this case, it helps to make the learning rate
slightly smaller (e.g., by a factor of 2) than typical vanilla LoRA learning rates
loraplus_lr_embedding (optional `float`):
If LoRA modules are added to embedding layers your can specify a different learning rate for them. Default
value 1e-6.
kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.

Returns:
`torch.optim.Optimizer`: An instance of the specified optimizer class configured with the model's parameters
organized into groups with custom learning rates.
"""

    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        module = attrgetter(name)(model)
        if isinstance(module, Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    loraplus_weight_decay = kwargs.pop("loraplus_weight_decay", 0.0)
    loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **kwargs)
    eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
    if optimizer_cls.__name__ in eight_bit_names:
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                manager.register_module_override(module, "weight", {"optim_bits": 32})
    return optimizer
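
Since the bitsandbytes-specific override above only triggers for the listed 8-bit optimizer class names, the helper can also be used with a standard torch optimizer. A minimal sketch under assumed settings (the base model, task type, and hyperparameter values are illustrative, not taken from this PR):

```py
import torch
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from transformers import AutoModelForCausalLM

# Illustrative base model and LoRA config (not from this PR).
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=torch.optim.AdamW,  # not in eight_bit_names, so no bnb override is registered
    lr=2e-5,
    loraplus_lr_ratio=16,             # the B groups then train at 2e-5 * 16 = 3.2e-4
    loraplus_weight_decay=0.01,
)
# Four param groups: A, embedding, B (decay), B (no decay), each with its own lr.
for group in optimizer.param_groups:
    print(group["lr"], group["weight_decay"], len(group["params"]))
```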
99 changes: 99 additions & 0 deletions tests/test_loraplus.py
@@ -0,0 +1,99 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
from torch import nn

from peft.import_utils import is_bnb_available
from peft.optimizers import create_loraplus_optimizer

from .testing_utils import require_bitsandbytes


if is_bnb_available():
    import bitsandbytes as bnb


class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X


@require_bitsandbytes
def test_lora_plus_helper_sucess():
    model = SimpleNet()
    optimizer_cls = bnb.optim.Adam8bit
    lr = 5e-5
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    loraplus_lr_ratio = 1.2
    loraplus_lr_embedding = 1e-6
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=lr,
        loraplus_lr_ratio=loraplus_lr_ratio,
        loraplus_lr_embedding=loraplus_lr_embedding,
        **optim_config,
    )
    assert optim is not None
    assert len(optim.param_groups) == 4
    assert optim.param_groups[0]["lr"] == lr
    assert optim.param_groups[1]["lr"] == loraplus_lr_embedding
    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (lr * loraplus_lr_ratio)


@require_bitsandbytes
def test_lora_plus_optimizer_sucess():
"""
Test if the optimizer is correctly created and step function runs without any exception
"""
optimizer_cls = bnb.optim.Adam8bit
optim_config = {
"eps": 1e-6,
"betas": (0.9, 0.999),
"loraplus_weight_decay": 0.0,
}
model: SimpleNet = SimpleNet().cuda()
optim = create_loraplus_optimizer(
model=model,
optimizer_cls=optimizer_cls,
lr=5e-5,
loraplus_lr_ratio=1.2,
loraplus_lr_embedding=1e-6,
**optim_config,
)
loss = torch.nn.CrossEntropyLoss()
bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
x = torch.randint(100, (2, 4, 10)).cuda()
output = model(x).permute(0, 3, 1, 2)
label = torch.randint(16, (2, 4, 10)).cuda()
loss_value = loss(output, label)
loss_value.backward()
optim.step()