
Is there any way to save lora-converted model? #12

Closed
Adamska1008 opened this issue Apr 3, 2024 · 5 comments

@Adamska1008

I tried to fine-tune TinyLlama with this crate. I used candle-lora/candle-lora-transformers/examples/llama.rs to load model.safetensors and did the training, but eventually found that there is no way to save the model in safetensors format.

I tried to implement a save method myself by wrapping candle_core::safetensors::save(), but how can I get the weights of the LoRA part? All I can access is the raw model from before it was converted to a LoRA model.

For example, if you run /candle-lora-macro/examples/linear.rs and add println!("{:?}", model.a);, you will see it printed as a Linear struct, not a LoraLinear struct, and you cannot get ff_a/ff_b from model.a, even though the model has been converted to a LoRA model.
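
Roughly what I tried — a minimal sketch, assuming the weights have to be collected into a HashMap by hand (the tensor names and shapes here are placeholders):

```rust
use std::collections::HashMap;

use candle_core::{safetensors, DType, Device, Result, Tensor};

fn save_weights(path: &str) -> Result<()> {
    let device = Device::Cpu;

    // Placeholder tensors: in the real code these would be the model's weights,
    // but all I can reach are the original layers, not the LoRA a/b matrices.
    let mut tensors: HashMap<String, Tensor> = HashMap::new();
    tensors.insert(
        "a.weight".to_string(),
        Tensor::zeros((8, 10), DType::F16, &device)?,
    );

    // candle_core::safetensors::save writes a HashMap<String, Tensor> to a .safetensors file.
    safetensors::save(&tensors, path)
}
```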

@EricLBuehler
Owner

EricLBuehler commented Apr 3, 2024

This is implemented and fixed in #13, which has been merged. Please note that the weight naming is incompatible with peft at the moment. If this is a problem, please feel free to raise an issue and I will fix it.

@Adamska1008
Author


Thank you very much! I tried this and got a 536 KB safetensors file with the header:

{"lora_llamaa0.weight":{"data_offsets":[0,512000],"dtype":"F16","shape":[8,32000]},"lora_llamab0.weight":{"data_offsets":[512000,544768],"dtype":"F16","shape":[2048,8]}}

Is this as expected? I would also like to know how to apply the LoRA tensors after loading a VarBuilder from the original model.

@EricLBuehler
Owner

No, the prefix was incorrect, but it should be fixed now. To load the LoRA tensors, pass get_lora_model the VarBuilder returned by from_mmaped_safetensors. Here is an example of loading the VarBuilder:

```rust
let vb = from_mmaped_safetensors(&filenames, dtype, &device, false)?;
```

That vb is then passed to get_lora_model:

```rust
if merge {
    this.get_merged_lora_model(
        lora_config,
        &vb.pp("lora_llama"),
        Some(linear_config),
        None,
        None,
        Some(embed_config),
    )
} else {
    this.get_lora_model(
        lora_config,
        &vb.pp("lora_llama"),
        Some(linear_config),
        None,
        None,
        Some(embed_config),
    )
}
```
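
For context, a sketch of how the filenames list might be built when the saved adapter is loaded alongside the base weights, so that the "lora_llama"-prefixed tensors are visible under vb.pp("lora_llama") (the adapter file name is hypothetical):

```rust
// Sketch: include both the base model and the saved LoRA adapter in the list that
// is memory-mapped, so get_lora_model can find the "lora_llama"-prefixed tensors.
let filenames = vec![
    std::path::PathBuf::from("model.safetensors"),        // original TinyLlama weights
    std::path::PathBuf::from("lora_adapter.safetensors"), // hypothetical name for the saved adapter
];
let vb = from_mmaped_safetensors(&filenames, dtype, &device, false)?;
```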

@Adamska1008
Author

Really helpful, thanks again!

@EricLBuehler
Owner

Glad to help!
