Integrating Riemannian Preconditioner #1807
Conversation
We have added a test file, peft/tests/riemannian_test.py, which uses the new optimizer to train an LLM with the Trainer class.
Thanks a lot for creating this draft PR to add Riemannian AdamW. I did a first review but haven't looked at the exact implementation details or compared them to the paper yet. I added some comments which, if addressed, will help me better understand what's going on.
Apart from the code comments I added, I have some more general comments:
- This PR contains the code from the LoRA+ PR. Please remove it.
- Could you please run `make style`?
- If some of this code is copied over from https://github.com/pilancilab/Riemannian_Preconditioned_LoRA or elsewhere, please add a comment with a reference.
- You added a test but it does not have the form of a proper unit test. I think it would be better to rewrite this a bit and add it to the `examples/` directory, as it's more akin to an example.
- Regarding proper unit tests, check out the tests from the LoRA+ PR. LMK if you need more guidance.
I know that overall, this seems to be a lot of work, but I'm sure we can get this into a good shape. If you have any questions, don't hesitate to ask.
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
Let's use the same indentation and syntax as the other parameters. Also, let's add docs for `reg`.
done
Hmm, indentation is still wrong. It should be:
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
reg (`float`): Regularization parameter for Riemmanian preconditioner. Included for lora parameters only
src/peft/optimizers/riemannian.py
Outdated
- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
"""

"""TEST VERSION FOR ADAMW"""
For code comments, use `#` and not strings.
done
src/peft/optimizers/riemannian.py
Outdated
""" | ||
|
||
"""TEST VERSION FOR ADAMW""" | ||
assert optimizer_cls.__name__=='AdamW', 'TEST version only supports AdamW optimizer' |
Let's not use `assert` in code (only tests). Here, it is better to raise a `TypeError`. Also, I wonder: does the class have to be `AdamW` or can it be a subclass? If the latter, you can change the check to: `if not issubclass(optimizer_cls, torch.optim.AdamW)`.
done
src/peft/optimizers/riemannian.py
Outdated
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # print(name, param.shape)
Please remove.
done
src/peft/optimizers/riemannian.py
Outdated
""" | ||
Creates a Riemmanian optimizer. | ||
Implementation: https://github.com/pilancilab/Riemannian_Preconditioned_LoRA | ||
Reference: https://arxiv.org/pdf/2402.02347 |
Let's mention that this only works for LoRA.
done
src/peft/optimizers/riemannian.py
Outdated
for group in self.param_groups:
    if group['is_lora']:
        for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
Let me try to understand this: I think we iterate over pairs of `lora_A` and `lora_B`, which is why we have the `zip` and the `[::2]`. Is that it?
I wonder if we can make the assumption that pairs of `lora_A` and `lora_B` always follow consecutively. E.g. what would happen if we had `use_dora=True`, could it happen that we suddenly have triplets?
Your understanding is correct, and this is exactly what I'm worried about. In our paper, for each LoRA pair (lora_A, lora_B), we use grad(lora_A) @ inverse(lora_B' lora_B) in place of the vanilla grad(lora_A). For our paper's results, we simply tested and observed that this modified gradient works better than the vanilla gradient with respect to loss minimization. Moreover, since lora_B' lora_B has shape r*r, inverse(lora_B' lora_B) is not expected to take long, especially for small r. Our original implementation is basic and we just iterate like `[::2]`.
In the current development, I'm not sure how to pair up (lora_A, lora_B) in an error-free way. As you mentioned, for DoRA we also have the magnitude term, so I feel it's better to actually obtain these pairs by matching names, e.g., "layer1_attentionq_lora_A" and "layer1_attentionq_lora_B". This is also better for preserving order, since I don't think we can assume each lora_A is immediately followed by its corresponding lora_B.
Moreover, the `[::2]` loop is indeed slow compared to the plain AdamW loop, so in addition to the inverse operation we also suffer from the loop runtime overhead. Shall we instead keep a dict for the lora_A and lora_B parameters respectively and directly query the corresponding value by key when needed?
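For reference, a rough sketch of the name-matching idea (a hypothetical helper, not part of the PR, assuming parameter names contain "lora_A"/"lora_B" as in PEFT):

def pair_lora_params(model):
    # Collect lora_A and lora_B parameters keyed by the rest of their name, so that
    # "...attn.q.lora_A.default.weight" is matched with "...attn.q.lora_B.default.weight".
    lora_A, lora_B = {}, {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_A" in name:
            lora_A[name.replace("lora_A", "")] = param
        elif "lora_B" in name:
            lora_B[name.replace("lora_B", "")] = param
    # Only keys present in both dicts form a valid (lora_A, lora_B) pair;
    # DoRA magnitude vectors and other parameters are simply left out.
    return [(lora_A[k], lora_B[k]) for k in lora_A if k in lora_B]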
src/peft/optimizers/riemannian.py
Outdated
for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
    grad = p1.grad
    if grad.is_sparse:
        raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") | |
raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients") |
Not sure if it makes sense to suggest SparseAdam here.
src/peft/optimizers/riemannian.py
Outdated
reg_I = self.defaults['reg']*torch.eye(min(p2.shape)).to(p2.device)
scaler = torch.inverse(scaler@scaler.T+reg_I) if p2.shape[0]<p2.shape[1] \
    else torch.inverse(scaler.T@scaler+reg_I)
assert scaler.shape[0]==min(p2.data.shape), 'wrong dimension'
Again, let's not use `assert` but raise a proper error here (a `ValueError` with a useful message).
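For instance, the assert could be replaced along these lines (only a sketch, reusing the `scaler` and `p2` names from the diff above):

if scaler.shape[0] != min(p2.shape):
    raise ValueError(
        f"Riemannian preconditioner has shape {tuple(scaler.shape)}, "
        f"expected its first dimension to equal min(p2.shape)={min(p2.shape)}"
    )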
src/peft/optimizers/riemannian.py
Outdated
    else torch.inverse(scaler.T@scaler+reg_I)
assert scaler.shape[0]==min(p2.data.shape), 'wrong dimension'
except:
    print('invalid condition')
Remove
src/peft/optimizers/riemannian.py
Outdated
if group["weight_decay"] > 0.0: | ||
p2.add_(p2, alpha=(-group["lr"] * group["weight_decay"])) | ||
|
||
else: |
Is this code path normal AdamW or are there changes in here too? Adding a comment would be helpful.
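For comparison, a rough sketch of what a plain (non-preconditioned) AdamW step looks like for a single parameter, loosely following the transformers implementation (the helper name and signature here are made up for illustration); the Riemannian path would differ only in replacing `grad` with the preconditioned gradient:

import torch

def plain_adamw_step(p, exp_avg, exp_avg_sq, step, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    # Vanilla AdamW: update first and second moment estimates of the gradient.
    grad = p.grad
    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    denom = exp_avg_sq.sqrt().add_(eps)
    # Bias-corrected step size.
    step_size = lr * ((1 - betas[1] ** step) ** 0.5) / (1 - betas[0] ** step)
    p.addcdiv_(exp_avg, denom, value=-step_size)
    # Decoupled weight decay is applied directly to the weights.
    if weight_decay > 0.0:
        p.add_(p, alpha=-lr * weight_decay)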
@fangzhaozhang do you still plan on working on this?
Yes, I'm going to implement the unit test this weekend. Sorry about the delay; I've recently been busy with some other research work.
I'm back on the implementation. Thanks so much for your detailed comments. With respect to the general points, I've also fixed small issues such as code comments, function names, etc., as suggested in the comments above. However, I'm not very sure about the following points and would be glad to hear your feedback/suggestions on them.
Thanks a lot for the updates. We're getting closer but there are still a few areas that need to be improved.
Also, note that the LoRA+ PR is now moved to #1915 with a few changes.
Thus I'm not sure whether it's best to make our method appear in peft/optimizers in parallel with LoRA+; it feels more natural to put our optimizer alongside the AdamW implementation, or to just pass in a parameter like lora=True to transformers' AdamW in order to switch to our method.
Since this is very PEFT specific, I think the best fit is indeed here. It would be quite hard to convince transformers to add this very specific change.
2. Besides, our method is not directly applicable to bitsandbytes and other quantized formats, since torch.inverse() is only compatible with certain dtypes. So shall we also do a dtype conversion before and after we compute torch.inverse() to make it more general?
If you can implement a version that works with quantized weights, that would be great. If not, that's also okay, but then let's document this clearly.
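One possible shape for the dtype handling (a sketch only, not tested against quantized layers; the helper name is made up): upcast to float32 for the inverse and cast back afterwards.

import torch

def regularized_inverse(mat: torch.Tensor, reg: float) -> torch.Tensor:
    # torch.inverse only supports float/double (and complex) dtypes, so upcast
    # e.g. bfloat16/float16 inputs to float32 before inverting, then cast back.
    orig_dtype = mat.dtype
    mat32 = mat.to(torch.float32)
    reg_I = reg * torch.eye(mat32.shape[0], device=mat32.device, dtype=torch.float32)
    return torch.inverse(mat32 + reg_I).to(orig_dtype)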
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
These lines can be removed. At the bottom of the file, add __all__ = ["create_riemannian_optimizer"]
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
Suggested change:
# Copyright 2023-present the HuggingFace Inc. team.
# Copyright 2024-present the HuggingFace Inc. team.
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
Hmm, indentation is still wrong. It should be:
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
reg (`float`): Regularization parameter for Riemmanian preconditioner. Included for lora parameters only
if not issubclass(optimizer_cls, torch.optim.AdamW):
    raise TypeError("TEST version only supports AdamW optimizer")
Since the `optimizer_cls` argument is not actually used except to raise an error, how about removing it completely?
def create_riemannian_optimizer(
    model: PeftModel,
    optimizer_cls: type[Optimizer],
    optimizer_kwargs: dict,
Since you probably took this from the LoRA+ PR, let me refer to the comment I put there:
A suggestion: Let's remove `optimizer_kwargs` and just add `**kwargs`. IMO, that makes calling this function easier, as we can use `create_riemannian_optimizer(..., weight_decay=1e-3)` instead of `create_riemannian_optimizer(..., optimizer_kwargs={..., "weight_decay": 1e-3})`. And since `lr` is not optional, let's make it a normal arg of `create_riemannian_optimizer`.
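Under that suggestion, the signature might look roughly like this (a sketch; the `lr_embedding` and `reg` defaults are placeholders, not values from the PR):

def create_riemannian_optimizer(
    model: PeftModel,
    *,
    lr: float,
    lr_embedding: float = 1e-6,
    reg: float = 1e-4,
    **kwargs,
) -> Optimizer:
    # **kwargs (e.g. weight_decay, betas) is forwarded directly to the AdamW constructor.
    ...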
for group in self.param_groups:
    if group["is_lora"]:
        for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
As discussed in the other comment, this is indeed error prone. For this, the logic there should be improved: I think it's better if we create two separate groups for `lora_A` and `lora_B`. After the loop there, let's also check that both groups have the same length and that the length is > 0. In the `optimizer_grouped_parameters`, we can set `"is_lora_A": True` and `"is_lora_B": True` accordingly.
After making this change, the line here could be simplified to:
# this works because there is exactly one lora_A and one lora_B group
lora_A_group = next(group for group in self.param_groups if group.get("is_lora_A"))
lora_B_group = next(group for group in self.param_groups if group.get("is_lora_B"))
for p1, p2 in zip(lora_A_group["params"], lora_B_group["params"]):
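A possible sketch of how those two groups could be built (the variable names and the name-based matching are assumptions for illustration, not the PR's actual code):

param_groups = {"lora_A": [], "lora_B": [], "other": []}
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if "lora_A" in name:
        param_groups["lora_A"].append(param)
    elif "lora_B" in name:
        param_groups["lora_B"].append(param)
    else:
        param_groups["other"].append(param)

# Both LoRA groups must be non-empty and of equal length so they can be zipped pairwise.
if not (0 < len(param_groups["lora_A"]) == len(param_groups["lora_B"])):
    raise ValueError("Expected the same non-zero number of lora_A and lora_B parameters")

optimizer_grouped_parameters = [
    {"params": param_groups["lora_A"], "is_lora_A": True, "is_lora_B": False},
    {"params": param_groups["lora_B"], "is_lora_A": False, "is_lora_B": True},
    {"params": param_groups["other"], "is_lora_A": False, "is_lora_B": False},
]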
    if p2.shape[0] < p2.shape[1]
    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
Let's not use assert, instead raise a proper `ValueError` with a helpful message.
    if p1.shape[0] < p1.shape[1]
    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
Let's not use assert, instead raise a proper ValueError with a helpful message.
    else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
except RuntimeError:
Could you explain why this is needed? Could we instead check the condition and do something like `if valid_condition: ... else: scaler = None`? Let's completely avoid printing messages.
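For instance, something along these lines could replace the try/except (a sketch only; whether a finiteness check is the right condition depends on what the except was actually guarding against), reusing the `scaler`, `p2`, and `reg_I` names from the diff:

gram = scaler @ scaler.T if p2.shape[0] < p2.shape[1] else scaler.T @ scaler
gram = gram + reg_I
if torch.isfinite(gram).all():
    scaler = torch.inverse(gram)
else:
    # Fall back to the unpreconditioned gradient instead of printing a message.
    scaler = None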
)
assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
except RuntimeError:
    print("invalid condition")
Could you explain why this is needed? Could we instead check the condition and do something like `if valid_condition: ... else: scaler = None`? Let's completely avoid printing messages.
Cool! We should ensure that we add documentation clarifying whether this works together with LoRA+ or whether the two are mutually exclusive for some reason.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
@fangzhaozhang Do you still plan on finishing this PR?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Paper link: https://arxiv.org/pdf/2402.02347
This is an attempt to integrate a special optimizer for LoRA training into the current Hugging Face PEFT codebase. We follow the structure of the LoRA+ PR (#1509).
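For context, the intended usage would look roughly like the following (a sketch based on the LoRA+ helper pattern; the import path, `peft_model`, and `train_dataset` are placeholders/assumptions):

import torch
from transformers import Trainer, TrainingArguments
from peft.optimizers import create_riemannian_optimizer  # assumed import path

# peft_model is a PeftModel with LoRA adapters attached; train_dataset is any HF dataset.
optimizer = create_riemannian_optimizer(
    model=peft_model,
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={"lr": 2e-4, "weight_decay": 0.0},
)

trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(output_dir="out", max_steps=100),
    train_dataset=train_dataset,
    optimizers=(optimizer, None),  # custom optimizer, default LR scheduler
)
trainer.train()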