Add LoRA+ implementation #1509
Conversation
Duplicate of #1504 :) Sorry about closing (wrong button). |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
What was the conclusion in that issue? |
No conclusion yet, we want to wait and see if the performance gains are indeed robust. Regarding your code, it's basically just a giant string with the code, right? Was that the intent? |
I'm waiting for you to decide whether I should implement a new Trainer object or not. |
Hey, after some discussion, I think we can proceed with this project. Let's add it. Some considerations: |
@moghadas76 do you still plan on working on this? |
Yes, this weekend I'll fix the points. |
Great, thanks. On top of what I mentioned, let's also move this to a new file. I'm thinking src/peft/optimizers/loraplus.py. |
Please review my code |
Thanks for working on this. It is a good start but there are a few issues, please check my comments. On top of that, could you please move the function out of helpers.py into a separate module, as I mentioned above?
Great, thanks. On top of what I mentioned, let's also move this to a new file. I'm thinking src/peft/optimizers/loraplus.py. The idea here is that we want to add more optimizer-related methods in the future, so it makes sense to choose a proper file structure right away.
Moreover, it would be great to document this function in our PEFT docs, but it would be fine to do that in a follow-up PR.
Finally, please run make style on your changes.
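For reference, a minimal usage sketch of what the proposed layout would look like from the user side, based on the signature and test configuration shown later in this thread. The peft.optimizers import path follows the __init__.py in the diff below; the gpt2 base model and the AdamW choice are only illustrative assumptions:

import torch
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer  # import path assumed from the __init__.py in the diff below
from transformers import AutoModelForCausalLM

# Any PEFT model with LoRA adapters works here; gpt2 is only an example.
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

# In this PR's version of the API, loraplus_lr_ratio travels inside optimizer_kwargs
# and is popped by the helper; the remaining keys are forwarded to the optimizer.
optimizer_kwargs = {
    "lr": 5e-5,
    "eps": 1e-6,
    "betas": (0.9, 0.999),
    "weight_decay": 0.0,
    "loraplus_lr_ratio": 0.2,
}
optimizer = create_loraplus_optimizer(
    model=peft_model,
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs=optimizer_kwargs,
    loraplus_lr_embedding=1e-6,
)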
@moghadas76 Do you still plan on working on this? |
Yes, I'll fix the comments tonight. |
Thanks. No need to rush, I just wanted to inquire if you're still on it :) |
@moghadas76 LMK once you're finished with your changes and want me to do another review. |
Gentle ping @moghadas76 |
Hi |
Sorry for the delay, I was at a conference, will review soon. |
Thanks for making the adjustments, this already looks quite good. I still found a few minor areas for improvements, which I commented. Also, as mentioned in my earlier comment, could you please move the code to a different file?
Hmm, code quality checks are still failing with:
Is it possible that your local ruff version differs? CI uses v0.2.2. |
You were right. My ruff version was old. |
Thanks for the updates. Our code style check still fails though, not sure what the reason is if you use the same ruff version. Here is the diff that I get when running ruff locally on your branch:
modified src/peft/optimizers/__init__.py
@@ -17,4 +17,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .loraplus import create_loraplus_optimizer
\ No newline at end of file
+from .loraplus import create_loraplus_optimizer
modified src/peft/optimizers/loraplus.py
@@ -8,20 +8,24 @@ from transformers.trainer_pt_utils import get_parameter_names
from ..peft_model import PeftModel
-def create_loraplus_optimizer(model: PeftModel, optimizer_cls: type[Optimizer], optimizer_kwargs: dict, loraplus_lr_embedding: float=1e-6) -> Optimizer:
+def create_loraplus_optimizer(
+ model: PeftModel, optimizer_cls: type[Optimizer], optimizer_kwargs: dict, loraplus_lr_embedding: float = 1e-6
+) -> Optimizer:
"""
- Creates a LoraPlus optimizer.
- Implementing LoRA+ https://arxiv.org/abs/2402.12354
- Reference: https://github.com/nikhil-ghosh-berkeley/loraplus/
+ Creates a LoraPlus optimizer. Implementing LoRA+ https://arxiv.org/abs/2402.12354 Reference:
+ https://github.com/nikhil-ghosh-berkeley/loraplus/
Args:
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
- - **loraplus_lr_ratio** (`float`): The ratio of the learning rate to be used for the embedding layer. Defaults to loraplus_lr_ratio
- - loraplus_lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to loraplus_lr_embedding
+ - **loraplus_lr_ratio** (`float`): The ratio of the learning rate to be used for the embedding layer.
+ Defaults to loraplus_lr_ratio
+ - loraplus_lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to
+ loraplus_lr_embedding
"""
from ..tuners.lora.layer import Embedding
+
loraplus_lr_ratio = optimizer_kwargs.pop("loraplus_lr_ratio")
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
@@ -81,6 +85,7 @@ def create_loraplus_optimizer(model: PeftModel, optimizer_cls: type[Optimizer],
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
for module in model.modules():
if isinstance(module, nn.Embedding):
modified tests/test_loraplus_helper.py
@@ -25,32 +25,37 @@ def test_lora_plus_helper_sucess():
model = SimpleNet()
optimizer_cls = bnb.optim.Adam8bit
optim_config = {
- 'lr': 5e-5,
- 'eps': 1e-6,
- 'betas': (0.9, 0.999),
- 'weight_decay': 0.0,
+ "lr": 5e-5,
+ "eps": 1e-6,
+ "betas": (0.9, 0.999),
+ "weight_decay": 0.0,
"loraplus_lr_ratio": 0.2,
}
- optim = create_loraplus_optimizer(model=model, optimizer_cls=optimizer_cls, optimizer_kwargs=optim_config, loraplus_lr_embedding=1e-6)
+ optim = create_loraplus_optimizer(
+ model=model, optimizer_cls=optimizer_cls, optimizer_kwargs=optim_config, loraplus_lr_embedding=1e-6
+ )
assert optim is not None
assert len(optim.param_groups) == 4
+
def test_lora_plus_optimizer_sucess():
optimizer_cls = bnb.optim.Adam8bit
optim_config = {
- 'lr': 5e-5,
- 'eps': 1e-6,
- 'betas': (0.9, 0.999),
- 'weight_decay': 0.0,
+ "lr": 5e-5,
+ "eps": 1e-6,
+ "betas": (0.9, 0.999),
+ "weight_decay": 0.0,
"loraplus_lr_ratio": 0.2,
}
model: SimpleNet = SimpleNet().cuda()
- optim = create_loraplus_optimizer(model=model, optimizer_cls=optimizer_cls, optimizer_kwargs=optim_config, loraplus_lr_embedding=1e-6)
+ optim = create_loraplus_optimizer(
+ model=model, optimizer_cls=optimizer_cls, optimizer_kwargs=optim_config, loraplus_lr_embedding=1e-6
+ )
loss = torch.nn.CrossEntropyLoss()
bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
x = torch.randint(100, (2, 4, 10)).cuda()
output = model(x).permute(0, 3, 1, 2)
- label = torch.randint(16, (2,4,10,)).cuda()
+ label = torch.randint(16, (2, 4, 10)).cuda()
loss_value = loss(output, label)
loss_value.backward()
optim.step()
Could you determine what the problem is with this traceback? Traceback (most recent call last): |
Where did you see that? The code quality check only spits out this:
Also, these lines could probably be reduced to a single line: I guess what happened here is that your editor added those line breaks because it is configured with a lower line limit than what we have in PEFT. |
@moghadas76 @BenjaminBossan I can take this up and make the necessary changes if you are short on time. We can aim to get this PR merged this week; let me know if it's okay. |
That's fine from my point of view. A separate PR with credits given would also work for me. For my understanding: Who is "we" in this case, are you collaborating with moghadas76? |
I am doing it by myself. By "we" I just meant you and me, and @moghadas76 if they are available. Also, can you advise on how to provide the credit? |
I see. Sure, please go ahead. As you probably can't push on top of this PR, feel free to create a new one. If we don't hear back from moghadas76 by the time the new PR is ready to be merged, we can add them as a co-author. |
I am happy to clean this up too. IMO the API is not the clearest as currently presented: the embedding LR and the ratio should be either both optimizer_kwargs or both named args. Finally, should the 8-bit -> 32-bit upcast be applied to all the 8-bit optimizers? |
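To make the suggestion concrete, here is a sketch of the API shape being proposed in this comment. It is not the implementation in this PR, and the parameter-grouping logic itself is elided:

from torch.optim import Optimizer

# bitsandbytes 8-bit Adam variants that would all get the 8-bit -> 32-bit embedding
# upcast, per the review discussion further down in this thread.
EIGHT_BIT_OPTIMIZER_NAMES = ("Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit")


def create_loraplus_optimizer(
    model,
    optimizer_cls: type[Optimizer],
    optimizer_kwargs: dict,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: float = 1e-6,
) -> Optimizer:
    """Signature sketch only: both LoRA+ learning-rate knobs are explicit arguments,
    and optimizer_kwargs carries only what is forwarded to the optimizer itself."""
    # The parameter grouping from the PR would go here, followed by something like:
    #     if optimizer_cls.__name__ in EIGHT_BIT_OPTIMIZER_NAMES: ...
    raise NotImplementedError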
@stillmatic Thanks, that would also be fine, just pinging @shubhamjain0594 to ensure that there won't be any duplicate work. |
This is very disrespectful. He stole this branch. |
@BenjaminBossan if @stillmatic has time then sure please go for it. I was going to raise a PR today, but was mainly looking to fix some documentation and other small bugs you had raised. @moghadas76 not stealing anyone's work here. Just want to get this PR merged so that I can start using it in my repo without doing weird installation. I have not yet raised a PR, and can wait if you have time to get this done. |
Happy to make my suggestions as comments on this branch if you have the time to address them here. I appreciate the work - I used the implementation here in my training, but ran into some problems, hence seeing what needs improvement. |
Assuming proper credit is given, there's nothing disrespectful about picking up someone's work if they are unable to complete it in a timely manner; there are month+ gaps between you receiving review and you actually addressing it. The PR is |
I have rebased and done some fixes on top of this pull request here: https://github.com/kallewoof/peft/tree/202407-loraplus @moghadas76 You can either base your work off of my fixes or redo it yourself. Whatever gets this merged the fastest. I specifically did not make a pull request out of this, as it sounds like you really want to do this yourself. @stillmatic Did I get your suggestions in there correctly? |
I agree that this is not about "stealing" work. All of this is a big collaboration, after all the initial PR was heavily based on https://github.com/nikhil-ghosh-berkeley/loraplus/blob/main/lora_plus.py. It would be best if we can get this PR over the finish line, as we're not missing a lot. @moghadas76 if you are still interested, let's try to finish this in the next two weeks. There has been some good feedback in this thread, so I'm sure we have everything we need to get this ready. If there is no progress here, I'm happy to merge other PRs that implement the same idea, with proper references being given. I'll ensure that co-authorship is respected when merging. |
]

optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
should this support the other 8-bit adam implementations?
@kallewoof this is the only important comment I had; the simplest would be:
eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
if optimizer_cls.__name__ in eight_bit_names:
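In context, based on the hunk quoted above, the check would slot in roughly like this. This is a fragment of the function body, not standalone code; the register_module_override call mirrors the upstream LoRA+ reference implementation and is an assumption here:

optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
if optimizer_cls.__name__ in eight_bit_names:
    import bitsandbytes

    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            # keep 32-bit optimizer state for the embedding weights
            manager.register_module_override(module, "weight", {"optim_bits": 32})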
Thanks. Added to proposed branch.
""" | ||
from ..tuners.lora.layer import Embedding | ||
|
||
loraplus_lr_ratio = optimizer_kwargs.pop("loraplus_lr_ratio") |
it's confusing to me that loraplus_lr_ratio is an optimizer_kwarg while loraplus_lr_embedding is a function argument. IMO both should be function arguments, while optimizer_kwargs should reflect the arguments passed to the optimizer.
I made both of them args, outside of optimizer_kwargs, in https://github.com/kallewoof/peft/tree/202407-loraplus FWIW, based on your comment.
Args:
    model (`torch.nn.Module`): The model to be optimized.
    optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
    optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
Should note explicitly that lr and weight_decay are expected.
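For illustration, the kind of optimizer_kwargs the docstring would then describe, with values taken from the tests in the diff above; the per-key notes are assumptions based on how the helper uses them:

optimizer_kwargs = {
    "lr": 5e-5,                # expected: base learning rate; the lora_B groups are scaled by loraplus_lr_ratio
    "weight_decay": 0.0,       # expected: applied to the decay parameter groups
    "eps": 1e-6,               # forwarded to the optimizer unchanged
    "betas": (0.9, 0.999),     # forwarded to the optimizer unchanged
    "loraplus_lr_ratio": 0.2,  # popped by the helper in this PR's version of the API
}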
Addressed in the proposed changes in https://github.com/kallewoof/peft/tree/202407-loraplus
It's been a week, so I opened the above branch as a pull req. Close it if that's not OK, @BenjaminBossan. |
Thanks @kallewoof. Starting tomorrow until the end of the week, I'll be at EuroPython Prague, so I will have little time for reviews etc. If by then, there is no progress on this PR, we can close it and continue with yours. As mentioned earlier, I'll make sure to assign proper credit before merging. |
Superseded by #1915. |
Implementing LoRA+ https://arxiv.org/abs/2402.12354
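For context, the core idea of LoRA+ is to train the LoRA B matrices with a learning rate that is loraplus_lr_ratio times larger than the one used for the A matrices. A minimal, simplified sketch of the resulting parameter groups; the name matching here is an assumption and much coarser than the helper in this PR:

import torch


def loraplus_param_groups(model: torch.nn.Module, lr: float, loraplus_lr_ratio: float):
    # Split trainable parameters by name: lora_B weights get a scaled-up learning
    # rate, everything else (including lora_A) keeps the base learning rate.
    group_a, group_b = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (group_b if "lora_B" in name else group_a).append(param)
    return [
        {"params": group_a, "lr": lr},
        {"params": group_b, "lr": lr * loraplus_lr_ratio},
    ]


# e.g. optimizer = torch.optim.AdamW(loraplus_param_groups(peft_model, lr=5e-5, loraplus_lr_ratio=0.2))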