Implement weight retrieval method
EricLBuehler committed Apr 3, 2024
1 parent ec14269 commit 8f96feb
Showing 13 changed files with 200 additions and 14 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -54,5 +54,10 @@ transformers have been converted:

To use a LoRA transformer, simply replace the model from `candle-transformers` with its counterpart in `candle-lora-transformers`!

## Saving and loading
`candle_lora` supports retrieving the weights of LoRA adapters via the `get_tensors` method, which is generated automatically by `#[derive(AutoLoraConvert)]`. The returned map is meant to be passed to `candle_core::safetensors::save()`. To load, simply load a `VarBuilder` from the saved file and pass it to `get_lora_model`.

`candle_lora`'s weight naming is not compatible with `peft` yet.
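
For illustration, a minimal sketch of the round trip (the helper function, file name, and dtype here are assumptions for the example, not part of the API):

```rust
use std::collections::HashMap;

use candle_core::{safetensors, DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

/// Save the adapter weights returned by the generated `get_tensors()`, then
/// memory-map them back into a `VarBuilder` for `get_lora_model`.
fn save_and_reload(adapters: &HashMap<String, Tensor>, device: &Device) -> Result<()> {
    safetensors::save(adapters, "adapters.safetensors")?;
    let _vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["adapters.safetensors"], DType::F32, device)?
    };
    // Pass `_vb` (with the same LoRA configs used originally) back into the
    // model's `get_lora_model` to restore the adapters.
    Ok(())
}
```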

## Resources
`candle-lora`'s LoRA conversion implementations are based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf), as well as Microsoft's [implementation](https://github.com/microsoft/LoRA).
2 changes: 1 addition & 1 deletion candle-lora-macro/examples/linear.rs
@@ -51,7 +51,7 @@ fn main() {
None,
);

- println!("{:?}", model.a);
+ dbg!(model.get_tensors());

let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();

45 changes: 45 additions & 0 deletions candle-lora-macro/src/lib.rs
@@ -303,6 +312,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
}];);
}

let mut linear_get = TokenStream::new();
if !linear_fields.is_empty() {
quote_into::quote_into!(linear_get += [#{
for (namei,_) in linear_fields.iter() {
quote_into::quote_into!(linear_get += (self.#namei.get_tensors(&mut output)),)
}
}];);
}

let mut conv1d_stream = TokenStream::new();
if !conv1d_fields.is_empty() {
quote_into::quote_into!(conv1d_stream += [#{
@@ -312,6 +321,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
}];);
}

let mut conv1d_get = TokenStream::new();
if !conv1d_fields.is_empty() {
quote_into::quote_into!(conv1d_get += [#{
for (namei,_) in conv1d_fields.iter() {
quote_into::quote_into!(conv1d_get += (self.#namei.get_tensors(&mut output)),)
}
}];);
}

let mut conv2d_stream = TokenStream::new();
if !conv2d_fields.is_empty() {
quote_into::quote_into!(conv2d_stream += [#{
@@ -321,6 +339,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
}];);
}

let mut conv2d_get = TokenStream::new();
if !conv2d_fields.is_empty() {
quote_into::quote_into!(conv2d_get += [#{
for (namei,_) in conv2d_fields.iter() {
quote_into::quote_into!(conv2d_get += (self.#namei.get_tensors(&mut output)),)
}
}];);
}

let mut embed_stream = TokenStream::new();
if !embed_fields.is_empty() {
quote_into::quote_into!(embed_stream += [#{
@@ -330,6 +357,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
}];);
}

let mut embed_get = TokenStream::new();
if !embed_fields.is_empty() {
quote_into::quote_into!(embed_get += [#{
for (namei,_) in embed_fields.iter() {
quote_into::quote_into!(embed_get += (self.#namei.get_tensors(&mut output)),)
}
}];);
}

let mut linear_stream_assign = TokenStream::new();
if !linear_fields.is_empty() {
quote_into::quote_into!(linear_stream_assign += [#{
@@ -653,6 +689,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
#conv2d_merge_option1_stream_assign
#embed_merge_option1_stream_assign
}

pub fn get_tensors(&self) -> ::std::collections::HashMap<String, Tensor> {
let mut output = ::std::collections::HashMap::new();
#linear_get
#conv1d_get
#conv2d_get
#embed_get
output
}
}
}

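For orientation, a rough sketch of what this expands to for a hypothetical struct with a single converted linear field (the struct and field name are made up; the real expansion emits one call per converted field, as generated above):

```rust
use std::collections::HashMap;

use candle_core::Tensor;
use candle_lora::{LinearLayerLike, Saveable};

// Hypothetical struct as it looks after #[replace_layer_fields] has swapped
// the concrete layer for a trait object.
struct Model {
    a: Box<dyn LinearLayerLike>,
}

// Roughly what #[derive(AutoLoraConvert)] now emits in addition to the
// conversion methods: each converted field accumulates its tensors into a
// single map keyed by tensor name.
impl Model {
    pub fn get_tensors(&self) -> HashMap<String, Tensor> {
        let mut output = HashMap::new();
        self.a.get_tensors(&mut output);
        output
    }
}
```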
7 changes: 7 additions & 0 deletions candle-lora-transformers/src/llama.rs
@@ -3,6 +3,7 @@
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_lora::{
EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig,
Saveable,
};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{Embedding, Module, VarBuilder};
@@ -103,6 +104,12 @@ impl Module for LlamaLinear {
}
}

impl Saveable for LlamaLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!()
}
}

impl LinearLayerLike for LlamaLinear {
fn bias(&self) -> Option<&Tensor> {
self.inner.bias()
2 changes: 1 addition & 1 deletion candle-lora-transformers/src/resnet.rs
@@ -3,7 +3,7 @@
//! See "Deep Residual Learning for Image Recognition" He et al. 2015
//! <https://arxiv.org/abs/1512.03385>
- use candle_core::{Module, Result, D};
+ use candle_core::{Module, Result, Tensor, D};
use candle_lora::{Conv2dLayerLike, LoraConfig, LoraConv2dConfig};
use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
use candle_nn::{batch_norm, VarBuilder};
16 changes: 15 additions & 1 deletion candle-lora/src/frozenconv.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};

- use crate::{Conv1dLayerLike, Conv2dLayerLike};
+ use crate::{Conv1dLayerLike, Conv2dLayerLike, Saveable};

/// Conv1d, but with a `new` implementation that ensures the weights are detached (frozen).
#[derive(Debug)]
@@ -42,6 +44,12 @@ impl Module for FrozenConv1d {
}
}

impl Saveable for FrozenConv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for FrozenConv1d {
fn config(&self) -> &Conv1dConfig {
self.conv.config()
@@ -93,6 +101,12 @@ impl Module for FrozenConv2d {
}
}

impl Saveable for FrozenConv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for FrozenConv2d {
fn config(&self) -> &Conv2dConfig {
self.conv.config()
10 changes: 9 additions & 1 deletion candle-lora/src/frozenembed.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Result, Tensor};
use candle_nn::Embedding;

- use crate::EmbeddingLayerLike;
+ use crate::{EmbeddingLayerLike, Saveable};

/// Embedding, but with a `new` implementation that ensures the embeddings are detached (frozen).
#[derive(Debug)]
@@ -27,6 +29,12 @@ impl crate::Module for FrozenEmbedding {
}
}

impl Saveable for FrozenEmbedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for FrozenEmbedding {
fn embeddings(&self) -> &Tensor {
self.embed.embeddings()
10 changes: 9 additions & 1 deletion candle-lora/src/frozenlinear.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;

use candle_core::{Module, Result, Shape, Tensor};
use candle_nn::Linear;

- use crate::LinearLayerLike;
+ use crate::{LinearLayerLike, Saveable};

/// Linear, but with a `new` implementation that ensures the weight and/or biases are detached (frozen).
#[derive(Debug)]
@@ -27,6 +29,12 @@ impl Module for FrozenLinear {
}
}

impl Saveable for FrozenLinear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for FrozenLinear {
fn bias(&self) -> Option<&Tensor> {
self.linear.bias()
36 changes: 32 additions & 4 deletions candle-lora/src/lib.rs
@@ -211,13 +211,23 @@ pub struct NewLayers<T: Eq + PartialEq + Hash> {
pub embed: HashMap<T, LoraEmbedding>,
}

pub trait Saveable {
fn get_tensors(&self, accum: &mut HashMap<String, Tensor>);
}

/// Any layer that is linear-like.
- pub trait LinearLayerLike: Module + Debug {
+ pub trait LinearLayerLike: Module + Debug + Saveable {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn shape(&self) -> &Shape;
}

impl Saveable for Linear {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl LinearLayerLike for Linear {
fn weight(&self) -> &Tensor {
self.weight()
@@ -231,12 +241,18 @@ impl LinearLayerLike for Linear {
}

/// Any layer that is conv1d-like.
- pub trait Conv1dLayerLike: Module + Debug {
+ pub trait Conv1dLayerLike: Module + Debug + Saveable {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv1dConfig;
}

impl Saveable for Conv1d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv1dLayerLike for Conv1d {
fn config(&self) -> &Conv1dConfig {
self.config()
@@ -250,12 +266,18 @@ impl Conv1dLayerLike for Conv1d {
}

/// Any layer that is conv2d-like.
- pub trait Conv2dLayerLike: Module + Debug {
+ pub trait Conv2dLayerLike: Module + Debug + Saveable {
fn weight(&self) -> &Tensor;
fn bias(&self) -> Option<&Tensor>;
fn config(&self) -> &Conv2dConfig;
}

impl Saveable for Conv2d {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl Conv2dLayerLike for Conv2d {
fn config(&self) -> &Conv2dConfig {
self.config()
@@ -269,11 +291,17 @@ impl Conv2dLayerLike for Conv2d {
}

/// Any layer that is embedding-like.
- pub trait EmbeddingLayerLike: Module + Debug {
+ pub trait EmbeddingLayerLike: Module + Debug + Saveable {
fn embeddings(&self) -> &Tensor;
fn hidden_size(&self) -> usize;
}

impl Saveable for Embedding {
fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
}
}

impl EmbeddingLayerLike for Embedding {
fn embeddings(&self) -> &Tensor {
self.embeddings()
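As a sketch of the new trait's contract (the layer type and key name below are made up): an implementor inserts its named tensors into the shared accumulator; this is how the LoRA layers contribute their A/B matrices, while the plain `candle_nn` and frozen layers deliberately panic with `unimplemented!`.

```rust
use std::collections::HashMap;

use candle_core::Tensor;
use candle_lora::Saveable;

// Hypothetical custom layer that wants its weights picked up alongside the
// LoRA adapters: implement `Saveable` and insert named tensors into the
// shared accumulator.
#[derive(Debug)]
struct MyGate {
    weight: Tensor,
}

impl Saveable for MyGate {
    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
        // The key becomes the tensor's name in the saved safetensors file.
        accum.insert("my_gate.weight".to_string(), self.weight.clone());
    }
}
```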
20 changes: 19 additions & 1 deletion candle-lora/src/loraconv1d.rs
@@ -1,4 +1,4 @@
- use std::ops::Mul;
+ use std::{collections::HashMap, ops::Mul};

use candle_core::{Module, Result, Tensor};
use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarBuilder};
@@ -7,6 +7,7 @@ use trc::Trc;

use crate::{
frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
Saveable,
};

#[derive(Debug, Clone)]
@@ -17,6 +18,8 @@ pub struct LoraConv1d {
scale: Option<f64>,
dropout: Option<Trc<Dropout>>,
merged: bool,
prefix: String,
id: usize,
}

#[derive(Clone, Debug)]
@@ -73,6 +76,8 @@ impl LoraConv1d {
},
dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
merged: false,
prefix: vb.prefix(),
id,
})
}
}
@@ -155,6 +160,19 @@ impl Module for LoraConv1d {
}
}

impl Saveable for LoraConv1d {
fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
accum.insert(
self.prefix.clone() + &format!("a{}.weight", self.id),
self.a.clone(),
);
accum.insert(
self.prefix.clone() + &format!("b{}.weight", self.id),
self.b.clone(),
);
}
}

impl Conv1dLayerLike for LoraConv1d {
fn config(&self) -> &Conv1dConfig {
self.old.config()
