make lr and loraplus_lr_ratio required forced kw args
kallewoof committed Jul 18, 2024
1 parent daddf0b commit f983257
Showing 2 changed files with 10 additions and 9 deletions.
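
With this change, callers pass the base learning rate and the LoRA+ ratio as required keyword-only arguments (note the bare `*` in the new signature) instead of the helper pulling lr out of **kwargs. A minimal sketch of the updated call, assuming the usual peft.optimizers import path; the stand-in model, target modules, and hyperparameter values are illustrative (the tests below use a small SimpleNet with bnb.optim.Adam8bit):

    import torch
    from peft import LoraConfig, get_peft_model
    from peft.optimizers import create_loraplus_optimizer

    # Tiny stand-in LoRA model; any PeftModel with LoRA adapters would do.
    base_model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
    peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0", "1"]))

    optimizer = create_loraplus_optimizer(
        model=peft_model,
        optimizer_cls=torch.optim.AdamW,  # the tests below use bnb.optim.Adam8bit
        lr=5e-5,                          # required, keyword-only: base rate (eta_A)
        loraplus_lr_ratio=16,             # required, keyword-only: eta_B / eta_A, illustrative value
        loraplus_lr_embedding=1e-6,       # optional; popped from **kwargs (default 1e-6)
        weight_decay=0.0,                 # remaining kwargs go to the optimizer class
    )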
6 changes: 3 additions & 3 deletions src/peft/optimizers/loraplus.py
@@ -30,7 +30,7 @@
 
 
 def create_loraplus_optimizer(
-    model: PeftModel, optimizer_cls: type[Optimizer], loraplus_lr_ratio: float, **kwargs
+    model: PeftModel, optimizer_cls: type[Optimizer], *, lr: float, loraplus_lr_ratio: float, **kwargs
 ) -> Optimizer:
     """
     Creates a LoraPlus optimizer.
@@ -42,6 +42,7 @@ def create_loraplus_optimizer(
     Args:
         model (`torch.nn.Module`): The model to be optimized.
         optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
+        lr (`float`): The learning rate to be used for the optimizer.
         loraplus_lr_ratio (`float`):
             The ratio of learning ηB/ηA where ηA (lr) is passed in as the optimizer learning rate. Should be ≥1. Should
             be set in tandem with the optimizer learning rate (lr); should be larger when the task is more difficult
@@ -81,9 +82,8 @@ def create_loraplus_optimizer(
         else:
             param_groups["groupA"][name] = param
 
-    lr = kwargs["lr"]
     weight_decay = kwargs.get("weight_decay", 0.0)
-    loraplus_lr_embedding = kwargs.get("loraplus_lr_embedding", 1e-6)
+    loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)
 
     optimizer_grouped_parameters = [
         {
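One behavioural detail in the hunk above: loraplus_lr_embedding is now pop'd from **kwargs instead of read with get, so the helper consumes it rather than leaving it among the keyword arguments that are presumably forwarded to optimizer_cls, while weight_decay is still only read and passed through. A standalone toy illustration of that get vs. pop difference (handle_kwargs is a hypothetical stand-in, not peft code):

    # Toy reproduction of the kwargs handling after this commit: pop consumes
    # the key, get leaves it in place for later forwarding.
    def handle_kwargs(**kwargs):
        weight_decay = kwargs.get("weight_decay", 0.0)                     # read, stays in kwargs
        loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)  # read and removed
        return weight_decay, loraplus_lr_embedding, kwargs

    wd, emb_lr, remaining = handle_kwargs(weight_decay=0.0, loraplus_lr_embedding=1e-6, eps=1e-6)
    assert "loraplus_lr_embedding" not in remaining  # consumed by the helper
    assert "weight_decay" in remaining and "eps" in remaining  # still available for the optimizer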
13 changes: 7 additions & 6 deletions tests/test_loraplus.py
@@ -46,8 +46,8 @@ def forward(self, X):
 def test_lora_plus_helper_sucess():
     model = SimpleNet()
     optimizer_cls = bnb.optim.Adam8bit
+    lr = 5e-5
     optim_config = {
-        "lr": 5e-5,
         "eps": 1e-6,
         "betas": (0.9, 0.999),
         "weight_decay": 0.0,
@@ -57,15 +57,16 @@ def test_lora_plus_helper_sucess():
     optim = create_loraplus_optimizer(
         model=model,
         optimizer_cls=optimizer_cls,
-        optimizer_kwargs=optim_config,
+        lr=lr,
         loraplus_lr_ratio=loraplus_lr_ratio,
         loraplus_lr_embedding=loraplus_lr_embedding,
+        **optim_config,
     )
     assert optim is not None
     assert len(optim.param_groups) == 4
-    assert optim.param_groups[0]["lr"] == optim_config["lr"]
+    assert optim.param_groups[0]["lr"] == lr
     assert optim.param_groups[1]["lr"] == loraplus_lr_embedding
-    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (optim_config["lr"] * loraplus_lr_ratio)
+    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (lr * loraplus_lr_ratio)
 
 
 @require_bitsandbytes
@@ -75,7 +76,6 @@ def test_lora_plus_optimizer_sucess():
     """
     optimizer_cls = bnb.optim.Adam8bit
     optim_config = {
-        "lr": 5e-5,
         "eps": 1e-6,
         "betas": (0.9, 0.999),
         "weight_decay": 0.0,
@@ -84,9 +84,10 @@
     optim = create_loraplus_optimizer(
         model=model,
         optimizer_cls=optimizer_cls,
-        optimizer_kwargs=optim_config,
+        lr=5e-5,
         loraplus_lr_ratio=1.2,
         loraplus_lr_embedding=1e-6,
+        **optim_config,
     )
     loss = torch.nn.CrossEntropyLoss()
     bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
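Read together, the updated assertions above pin down the effective layout: four parameter groups, with the base lr on the LoRA A weights, loraplus_lr_embedding on embedding weights, and lr * loraplus_lr_ratio on the LoRA B weights. A standalone sketch of that layout with placeholder parameter lists; the group names mirror the groupA/groupB naming in loraplus.py, and the ratio value is illustrative rather than taken from the tests:

    # Effective per-group learning rates asserted by test_lora_plus_helper_sucess.
    lr = 5e-5
    loraplus_lr_ratio = 16          # illustrative; must now be passed explicitly
    loraplus_lr_embedding = 1e-6

    # Placeholders for the parameter groups the helper collects from the model.
    groupA, embedding, groupB, groupB_no_decay = [], [], [], []

    optimizer_grouped_parameters = [
        {"params": groupA, "lr": lr},                               # LoRA A weights
        {"params": embedding, "lr": loraplus_lr_embedding},         # embedding weights
        {"params": groupB, "lr": lr * loraplus_lr_ratio},           # LoRA B weights
        {"params": groupB_no_decay, "lr": lr * loraplus_lr_ratio},  # LoRA B, no weight decay
    ]

    assert [g["lr"] for g in optimizer_grouped_parameters] == [5e-5, 1e-6, 8e-4, 8e-4]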
