Add LoRA+: Efficient Low Rank Adaptation of Large Models (https://arxiv.org/abs/2402.12354)

Call create_loraplus_optimizer to initialize an optimizer with parameter groups and learning rates that are especially effective for LoRA training. Builds upon this code base: https://github.com/nikhil-ghosh-berkeley/loraplus

---------

Co-authored-by: moghadas76 <s.m.moghadas2012@gmail.com>
Co-authored-by: Chris Hua <stillmatic@users.noreply.github.com>
1 parent 296fbcd, commit 273acf0. Showing 5 changed files with 271 additions and 0 deletions.
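Before the diff itself, here is a minimal usage sketch of the new create_loraplus_optimizer helper. It is illustrative only and not part of the commit: the tiny MLP, the target_modules names, and the hyperparameter values (including loraplus_lr_ratio=16) are assumptions chosen for the example.

# Illustrative sketch (not from the commit): wrap a small custom model with LoRA,
# then build a LoRA+ optimizer that trains the B matrices at a higher learning rate.
import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(20, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


base_model = TinyMLP()
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["lin0", "lin1"]))

optimizer = create_loraplus_optimizer(
    model=peft_model,
    optimizer_cls=torch.optim.AdamW,  # any torch.optim.Optimizer subclass
    lr=5e-5,                          # learning rate for the LoRA A matrices
    loraplus_lr_ratio=16,             # B matrices train at lr * ratio (assumed example value)
    loraplus_lr_embedding=1e-6,       # only relevant if LoRA targets embedding layers
    loraplus_weight_decay=0.0,
)
# Four parameter groups are created: A matrices, embeddings, and B matrices with/without weight decay.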
@@ -0,0 +1,18 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loraplus import create_loraplus_optimizer


__all__ = ["create_loraplus_optimizer"]
@@ -0,0 +1,120 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the implementation of the LoraPlus optimizer.
"""

from __future__ import annotations

from operator import attrgetter

import torch.nn as nn
from torch.optim import Optimizer
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

from ..peft_model import PeftModel
from ..tuners.lora.layer import Embedding


def create_loraplus_optimizer(
    model: PeftModel, optimizer_cls: type[Optimizer], *, lr: float, loraplus_lr_ratio: float, **kwargs
) -> Optimizer:
    """
    Creates a LoraPlus optimizer.

    Efficient Low Rank Adaptation of Large Models: https://arxiv.org/abs/2402.12354

    Reference: https://github.com/nikhil-ghosh-berkeley/loraplus/

    Args:
        model (`torch.nn.Module`): The model to be optimized.
        optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
        lr (`float`): The learning rate to be used for the optimizer.
        loraplus_lr_ratio (`float`):
            The ratio of learning rates ηB/ηA, where ηA (lr) is passed in as the optimizer learning rate. Should be
            ≥ 1. Should be set in tandem with the optimizer learning rate (lr): it should be larger when the task is
            more difficult and the model needs to update its features to learn well. In that case, it helps to make
            the learning rate slightly smaller (e.g., by a factor of 2) than typical vanilla LoRA learning rates.
        loraplus_lr_embedding (optional `float`):
            If LoRA modules are added to embedding layers, you can specify a different learning rate for them.
            Defaults to 1e-6.
        kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.

    Returns:
        `torch.optim.Optimizer`: An instance of the specified optimizer class configured with the model's parameters
        organized into groups with custom learning rates.
    """

    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        module = attrgetter(name)(model)
        if isinstance(module, Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    loraplus_weight_decay = kwargs.pop("loraplus_weight_decay", 0.0)
    loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **kwargs)
    eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
    if optimizer_cls.__name__ in eight_bit_names:
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                manager.register_module_override(module, "weight", {"optim_bits": 32})
    return optimizer
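Because the helper returns a standard torch.optim.Optimizer, it can also be handed to the transformers Trainer through its optimizers argument. The sketch below is an assumption-laden illustration, not part of the commit: peft_model, train_dataset, and the TrainingArguments values are placeholders.

# Hedged sketch: plug the LoRA+ optimizer into a transformers Trainer.
# "peft_model" and "train_dataset" are assumed to be prepared elsewhere.
import torch
from transformers import Trainer, TrainingArguments

from peft.optimizers import create_loraplus_optimizer

optimizer = create_loraplus_optimizer(
    model=peft_model,                 # a LoRA-wrapped PeftModel (placeholder)
    optimizer_cls=torch.optim.AdamW,
    lr=5e-5,
    loraplus_lr_ratio=16,
)

trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(output_dir="loraplus-out", num_train_epochs=1),
    train_dataset=train_dataset,      # placeholder dataset
    # Trainer takes an (optimizer, lr_scheduler) tuple; None lets it build the default scheduler.
    optimizers=(optimizer, None),
)
trainer.train()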
@@ -0,0 +1,99 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import torch
from torch import nn

from peft.import_utils import is_bnb_available
from peft.optimizers import create_loraplus_optimizer

from .testing_utils import require_bitsandbytes


if is_bnb_available():
    import bitsandbytes as bnb


class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X

@require_bitsandbytes
def test_lora_plus_helper_success():
    model = SimpleNet()
    optimizer_cls = bnb.optim.Adam8bit
    lr = 5e-5
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    loraplus_lr_ratio = 1.2
    loraplus_lr_embedding = 1e-6
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=lr,
        loraplus_lr_ratio=loraplus_lr_ratio,
        loraplus_lr_embedding=loraplus_lr_embedding,
        **optim_config,
    )
    assert optim is not None
    assert len(optim.param_groups) == 4
    assert optim.param_groups[0]["lr"] == lr
    assert optim.param_groups[1]["lr"] == loraplus_lr_embedding
    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (lr * loraplus_lr_ratio)


@require_bitsandbytes
def test_lora_plus_optimizer_success():
    """
    Test that the optimizer is correctly created and that the step function runs without any exception.
    """
    optimizer_cls = bnb.optim.Adam8bit
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    model: SimpleNet = SimpleNet().cuda()
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=5e-5,
        loraplus_lr_ratio=1.2,
        loraplus_lr_embedding=1e-6,
        **optim_config,
    )
    loss = torch.nn.CrossEntropyLoss()
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
    x = torch.randint(100, (2, 4, 10)).cuda()
    output = model(x).permute(0, 3, 1, 2)
    label = torch.randint(16, (2, 4, 10)).cuda()
    loss_value = loss(output, label)
    loss_value.backward()
    optim.step()