Generating LoRA parameters with hypernetworks #12
Hi,

Thanks for open-sourcing this brilliant package! Similar to #6, I also want to apply hypernetworks (HN) to LoRA, but my setting is a bit different. I have a base transformer model, and I want to use LoRA to adapt it to multiple different tasks. Instead of learning a separate adapter per task, I want a single HN to generate different LoRA parameters for each task, by conditioning the HN on a task-specific context. During fine-tuning, only the parameters of the HN are updated, while the base transformer stays frozen. I was wondering if you could give some suggestions or examples on how to implement this HN idea in lorax? Thanks a lot for your help!

Comments
Yeah so basically I'd try using lorax to lorafy the parameters of your base transformer. Take a look at them in the Python interpreter, and figure out where the A/B values you want to generate are. Then, during your forward pass, just generate the As and Bs and populate the tree with them. If you give me some minimal example code I can give you more specific pointers.
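A minimal sketch of what that could look like, assuming the current qax-based lorax API where `lorax.init_lora` returns a tree with `lorax.LoraWeight(w, a, b)` leaves (older releases split the tree into a frozen/tunable pair instead, so check the README for your version). `base_params`, `model`, `hypernetwork`, and the tree path are placeholders, not part of lorax:

```python
import jax
import lorax
from flax import traverse_util

# 1. Lorafy the frozen base parameters: each targeted kernel in the tree
#    becomes a lorax.LoraWeight leaf holding (w, a, b). Rank 8 is arbitrary.
spec = lorax.simple_spec(base_params, decision_fn=lambda path, param: 8)
lora_params = lorax.init_lora(base_params, spec, jax.random.PRNGKey(0))

# 2. Wrap the plain apply function so it accepts LoraWeight leaves.
lora_apply = lorax.lora(model.apply)

# 3. Inspect lora_params in the interpreter to find the leaves whose a/b you
#    want to generate; the path below is made up for illustration.
target = ('params', 'layers_0', 'attention', 'query', 'kernel')

def forward(hn_params, task_context, x):
    # Generate the A and B matrices from the task context. `hypernetwork`
    # is your own flax module; it must emit arrays shaped like leaf.a / leaf.b.
    a, b = hypernetwork.apply(hn_params, task_context)

    # Populate the tree: rebuild it with the generated values swapped in.
    flat = traverse_util.flatten_dict(lora_params)
    old = flat[target]
    flat[target] = lorax.LoraWeight(w=old.w, a=a, b=b)
    populated = traverse_util.unflatten_dict(flat)

    return lora_apply(populated, x)
```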
Thanks for your help! Here is a minimal code example of the transformer module I use. The module is stacked for multiple layers to get the final model, as sketched below:
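(A rough stand-in for such a module, since the exact snippet isn't shown here; layer sizes, module names, and the plain Python loop used for stacking are illustrative only.)

```python
import flax.linen as nn

class TransformerLayer(nn.Module):
    """One block: self-attention plus an MLP, each with a residual connection."""
    d_model: int = 512
    n_heads: int = 8

    @nn.compact
    def __call__(self, x):
        # self-attention sub-block
        y = nn.LayerNorm()(x)
        y = nn.SelfAttention(num_heads=self.n_heads)(y)
        x = x + y
        # feed-forward sub-block
        y = nn.LayerNorm()(x)
        y = nn.Dense(4 * self.d_model)(y)
        y = nn.gelu(y)
        y = nn.Dense(self.d_model)(y)
        return x + y

class Transformer(nn.Module):
    """The final model: an embedding plus a stack of TransformerLayer blocks."""
    n_layers: int = 6
    d_model: int = 512
    vocab_size: int = 32000

    @nn.compact
    def __call__(self, tokens):
        x = nn.Embed(self.vocab_size, self.d_model)(tokens)
        for _ in range(self.n_layers):
            x = TransformerLayer(d_model=self.d_model)(x)
        x = nn.LayerNorm()(x)
        return nn.Dense(self.vocab_size)(x)
```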
My thought is to initialize the hypernetwork parameters inside the class …
I think the correct implementation depends a lot on what you want the hypernetwork to take as input. Are you going to compute all the LoRA weights at once, before beginning execution of your model? If so, I'd recommend making the hypernetwork a separate model (or at least making a new flax module which holds both the transformer and the hypernetwork). The reason is that it's actually a little bit annoying to interact with parameters from within the context of a flax model, since flax makes everything appear to be OOP-y. You'd probably have to look into using …
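A sketch of the separate-hypernetwork option, under the same assumptions as the earlier snippet: every LoRA weight is computed from the task context before the transformer runs, and only the hypernetwork parameters are optimized. `hypernetwork`, `transformer`, `base_params`, and the flattened-path convention for the generated (a, b) pairs are placeholders rather than anything lorax provides:

```python
import jax
import optax
import lorax
from flax import traverse_util

# Lorafy the frozen transformer once, up front (rank 8 as an example).
spec = lorax.simple_spec(base_params, decision_fn=lambda path, param: 8)
lora_params = lorax.init_lora(base_params, spec, jax.random.PRNGKey(0))
lora_apply = lorax.lora(transformer.apply)

optimizer = optax.adam(1e-3)

def forward(hn_params, lora_params, task_context, x):
    # Compute every (a, b) pair at once, before running the transformer.
    # `hypernetwork` is assumed to return a dict keyed by the same flattened
    # paths that appear in the lorafied tree.
    generated = hypernetwork.apply(hn_params, task_context)

    flat = traverse_util.flatten_dict(lora_params)
    for path, (a, b) in generated.items():
        flat[path] = lorax.LoraWeight(w=flat[path].w, a=a, b=b)
    populated = traverse_util.unflatten_dict(flat)

    return lora_apply(populated, x)

def loss_fn(hn_params, lora_params, task_context, batch):
    logits = forward(hn_params, lora_params, task_context, batch['inputs'])
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch['labels']).mean()

# Differentiate w.r.t. the hypernetwork parameters only: the base weights
# (and the a/b placeholders from init_lora) never receive gradients, which
# matches "only the HN is updated while the transformer stays frozen".
@jax.jit
def train_step(hn_params, opt_state, lora_params, task_context, batch):
    loss, grads = jax.value_and_grad(loss_fn)(hn_params, lora_params,
                                              task_context, batch)
    updates, opt_state = optimizer.update(grads, opt_state, hn_params)
    hn_params = optax.apply_updates(hn_params, updates)
    return hn_params, opt_state, loss
```

Since the optimizer only ever sees `hn_params`, the lorafied tree is just data passed through the forward pass and nothing in the base transformer can drift.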