From 8f96feb60116e5827c3a1c5b9926b8d5344920ed Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Wed, 3 Apr 2024 06:57:45 -0400
Subject: [PATCH] Implement weight retrieval method

---
 README.md                              |  5 +++
 candle-lora-macro/examples/linear.rs   |  2 +-
 candle-lora-macro/src/lib.rs           | 45 ++++++++++++++++++++++++++
 candle-lora-transformers/src/llama.rs  |  7 ++++
 candle-lora-transformers/src/resnet.rs |  2 +-
 candle-lora/src/frozenconv.rs          | 16 ++++++++-
 candle-lora/src/frozenembed.rs         | 10 +++++-
 candle-lora/src/frozenlinear.rs        | 10 +++++-
 candle-lora/src/lib.rs                 | 36 ++++++++++++++++++---
 candle-lora/src/loraconv1d.rs          | 20 +++++++++++-
 candle-lora/src/loraconv2d.rs          | 20 +++++++++++-
 candle-lora/src/loraembed.rs           | 21 ++++++++++--
 candle-lora/src/loralinear.rs          | 20 +++++++++++-
 13 files changed, 200 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index d977b8c..762e1b7 100644
--- a/README.md
+++ b/README.md
@@ -54,5 +54,10 @@ transformers have been converted:
 To use a LoRA transformer, simply replace the model from `candle-transformers` with its counterpart in `candle-lora-transformers`!
 
+## Saving and loading
+`candle_lora` supports retrieving the weights of LoRA adapters via the `get_tensors` method, which is generated automatically by the `AutoLoraConvert` derive macro. This method is meant to be used with `candle_core::safetensors::save()`. To load, simply load the `VarBuilder` and pass it to `get_lora_model`.
+
+`candle_lora`'s weight naming is not yet compatible with `peft`.
+
 ## Resources
 `candle-lora`'s LoRA conversion implementations are based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf), as well as Microsoft's [implementation](https://github.com/microsoft/LoRA).
\ No newline at end of file
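For illustration, the save/reload flow described above can be sketched as follows. The helper name, file path, and dtype are arbitrary, and the tensor map is assumed to be the `HashMap<String, Tensor>` produced by the generated `get_tensors`:

```rust
use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

/// Save LoRA adapter weights (as returned by the generated `get_tensors`)
/// and rebuild a `VarBuilder` over the saved file for reloading.
fn save_and_reload(
    tensors: &HashMap<String, Tensor>, // e.g. `model.get_tensors()`
    path: &str,                        // arbitrary file name
    device: &Device,
) -> Result<VarBuilder<'static>> {
    // Write the adapter weights to a safetensors file.
    candle_core::safetensors::save(tensors, path)?;

    // Read them back and wrap them in a `VarBuilder`; this is what gets
    // passed to the generated `get_lora_model` method, as in the
    // `linear.rs` example below.
    let loaded = candle_core::safetensors::load(path, device)?;
    Ok(VarBuilder::from_tensors(loaded, DType::F32, device))
}
```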
diff --git a/candle-lora-macro/examples/linear.rs b/candle-lora-macro/examples/linear.rs
index 03fe91c..58a2cc3 100644
--- a/candle-lora-macro/examples/linear.rs
+++ b/candle-lora-macro/examples/linear.rs
@@ -51,7 +51,7 @@ fn main() {
         None,
     );
 
-    println!("{:?}", model.a);
+    dbg!(model.get_tensors());
 
     let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();
 
diff --git a/candle-lora-macro/src/lib.rs b/candle-lora-macro/src/lib.rs
index 40e47a4..2f9b6fe 100644
--- a/candle-lora-macro/src/lib.rs
+++ b/candle-lora-macro/src/lib.rs
@@ -303,6 +303,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
         }];);
     }
 
+    let mut linear_get = TokenStream::new();
+    if !linear_fields.is_empty() {
+        quote_into::quote_into!(linear_get += [#{
+            for (namei,_) in linear_fields.iter() {
+                quote_into::quote_into!(linear_get += (self.#namei.get_tensors(&mut output)),)
+            }
+        }];);
+    }
+
     let mut conv1d_stream = TokenStream::new();
     if !conv1d_fields.is_empty() {
         quote_into::quote_into!(conv1d_stream += [#{
@@ -312,6 +321,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
         }];);
     }
 
+    let mut conv1d_get = TokenStream::new();
+    if !conv1d_fields.is_empty() {
+        quote_into::quote_into!(conv1d_get += [#{
+            for (namei,_) in conv1d_fields.iter() {
+                quote_into::quote_into!(conv1d_get += (self.#namei.get_tensors(&mut output)),)
+            }
+        }];);
+    }
+
     let mut conv2d_stream = TokenStream::new();
     if !conv2d_fields.is_empty() {
         quote_into::quote_into!(conv2d_stream += [#{
@@ -321,6 +339,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
         }];);
     }
 
+    let mut conv2d_get = TokenStream::new();
+    if !conv2d_fields.is_empty() {
+        quote_into::quote_into!(conv2d_get += [#{
+            for (namei,_) in conv2d_fields.iter() {
+                quote_into::quote_into!(conv2d_get += (self.#namei.get_tensors(&mut output)),)
+            }
+        }];);
+    }
+
     let mut embed_stream = TokenStream::new();
     if !embed_fields.is_empty() {
         quote_into::quote_into!(embed_stream += [#{
@@ -330,6 +357,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
         }];);
     }
 
+    let mut embed_get = TokenStream::new();
+    if !embed_fields.is_empty() {
+        quote_into::quote_into!(embed_get += [#{
+            for (namei,_) in embed_fields.iter() {
+                quote_into::quote_into!(embed_get += (self.#namei.get_tensors(&mut output)),)
+            }
+        }];);
+    }
+
     let mut linear_stream_assign = TokenStream::new();
     if !linear_fields.is_empty() {
         quote_into::quote_into!(linear_stream_assign += [#{
@@ -653,6 +689,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 {
             #conv2d_merge_option1_stream_assign
             #embed_merge_option1_stream_assign
         }
+
+        pub fn get_tensors(&self) -> ::std::collections::HashMap<String, candle_core::Tensor> {
+            let mut output = ::std::collections::HashMap::new();
+            #linear_get
+            #conv1d_get
+            #conv2d_get
+            #embed_get
+            output
+        }
     }
 }
 
diff --git a/candle-lora-transformers/src/llama.rs b/candle-lora-transformers/src/llama.rs
index c3c218e..a8c13c5 100644
--- a/candle-lora-transformers/src/llama.rs
+++ b/candle-lora-transformers/src/llama.rs
@@ -3,6 +3,7 @@ use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_lora::{
     EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig,
+    Saveable,
 };
 use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
 use candle_nn::{Embedding, Module, VarBuilder};
 
@@ -103,6 +104,12 @@ impl Module for LlamaLinear {
     }
 }
 
+impl Saveable for LlamaLinear {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!()
+    }
+}
+
 impl LinearLayerLike for LlamaLinear {
     fn bias(&self) -> Option<&Tensor> {
         self.inner.bias()
diff --git a/candle-lora-transformers/src/resnet.rs b/candle-lora-transformers/src/resnet.rs
index 434e5b0..5602a5c 100644
--- a/candle-lora-transformers/src/resnet.rs
+++ b/candle-lora-transformers/src/resnet.rs
@@ -3,7 +3,7 @@
 //! See "Deep Residual Learning for Image Recognition" He et al. 2015
 //!
 
-use candle_core::{Module, Result, D};
+use candle_core::{Module, Result, Tensor, D};
 use candle_lora::{Conv2dLayerLike, LoraConfig, LoraConv2dConfig};
 use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
 use candle_nn::{batch_norm, VarBuilder};
diff --git a/candle-lora/src/frozenconv.rs b/candle-lora/src/frozenconv.rs
index e1d4919..b4e528f 100644
--- a/candle-lora/src/frozenconv.rs
+++ b/candle-lora/src/frozenconv.rs
@@ -1,7 +1,9 @@
+use std::collections::HashMap;
+
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
 
-use crate::{Conv1dLayerLike, Conv2dLayerLike};
+use crate::{Conv1dLayerLike, Conv2dLayerLike, Saveable};
 
 /// Conv1d, but with a `new` implementation that ensures the weights are detached (frozen).
 #[derive(Debug)]
@@ -42,6 +44,12 @@ impl Module for FrozenConv1d {
     }
 }
 
+impl Saveable for FrozenConv1d {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
+    }
+}
+
 impl Conv1dLayerLike for FrozenConv1d {
     fn config(&self) -> &Conv1dConfig {
         self.conv.config()
@@ -93,6 +101,12 @@ impl Module for FrozenConv2d {
     }
 }
 
+impl Saveable for FrozenConv2d {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
+    }
+}
+
 impl Conv2dLayerLike for FrozenConv2d {
     fn config(&self) -> &Conv2dConfig {
         self.conv.config()
diff --git a/candle-lora/src/frozenembed.rs b/candle-lora/src/frozenembed.rs
index 9c77412..0341c30 100644
--- a/candle-lora/src/frozenembed.rs
+++ b/candle-lora/src/frozenembed.rs
@@ -1,7 +1,9 @@
+use std::collections::HashMap;
+
 use candle_core::{Result, Tensor};
 use candle_nn::Embedding;
 
-use crate::EmbeddingLayerLike;
+use crate::{EmbeddingLayerLike, Saveable};
 
 /// Embedding, but with a `new` implementation that ensures the embeddings are detached (frozen).
 #[derive(Debug)]
@@ -27,6 +29,12 @@ impl crate::Module for FrozenEmbedding {
     }
 }
 
+impl Saveable for FrozenEmbedding {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
+    }
+}
+
 impl EmbeddingLayerLike for FrozenEmbedding {
     fn embeddings(&self) -> &Tensor {
         self.embed.embeddings()
diff --git a/candle-lora/src/frozenlinear.rs b/candle-lora/src/frozenlinear.rs
index b241688..a6b869a 100644
--- a/candle-lora/src/frozenlinear.rs
+++ b/candle-lora/src/frozenlinear.rs
@@ -1,7 +1,9 @@
+use std::collections::HashMap;
+
 use candle_core::{Module, Result, Shape, Tensor};
 use candle_nn::Linear;
 
-use crate::LinearLayerLike;
+use crate::{LinearLayerLike, Saveable};
 
 /// Linear, but with a `new` implementation that ensures the weight and/or biases are detached (frozen).
 #[derive(Debug)]
@@ -27,6 +29,12 @@ impl Module for FrozenLinear {
     }
 }
 
+impl Saveable for FrozenLinear {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for frozen layers, only for candle_lora layers.");
+    }
+}
+
 impl LinearLayerLike for FrozenLinear {
     fn bias(&self) -> Option<&Tensor> {
         self.linear.bias()
diff --git a/candle-lora/src/lib.rs b/candle-lora/src/lib.rs
index 12c013a..e3a0c41 100644
--- a/candle-lora/src/lib.rs
+++ b/candle-lora/src/lib.rs
@@ -211,13 +211,23 @@ pub struct NewLayers<T: Eq + PartialEq + Hash> {
     pub embed: HashMap<T, LoraEmbedding>,
 }
 
+pub trait Saveable {
+    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>);
+}
+
 /// Any layer that is linear-like.
-pub trait LinearLayerLike: Module + Debug {
+pub trait LinearLayerLike: Module + Debug + Saveable {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn shape(&self) -> &Shape;
 }
 
+impl Saveable for Linear {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
+    }
+}
+
 impl LinearLayerLike for Linear {
     fn weight(&self) -> &Tensor {
         self.weight()
@@ -231,12 +241,18 @@ impl LinearLayerLike for Linear {
     }
 }
 /// Any layer that is conv1d-like.
-pub trait Conv1dLayerLike: Module + Debug {
+pub trait Conv1dLayerLike: Module + Debug + Saveable {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn config(&self) -> &Conv1dConfig;
 }
 
+impl Saveable for Conv1d {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
+    }
+}
+
 impl Conv1dLayerLike for Conv1d {
     fn config(&self) -> &Conv1dConfig {
         self.config()
@@ -250,12 +266,18 @@ impl Conv1dLayerLike for Conv1d {
     }
 }
 /// Any layer that is conv2d-like.
-pub trait Conv2dLayerLike: Module + Debug {
+pub trait Conv2dLayerLike: Module + Debug + Saveable {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn config(&self) -> &Conv2dConfig;
 }
 
+impl Saveable for Conv2d {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
+    }
+}
+
 impl Conv2dLayerLike for Conv2d {
     fn config(&self) -> &Conv2dConfig {
         self.config()
@@ -269,11 +291,17 @@ impl Conv2dLayerLike for Conv2d {
     }
 }
 /// Any layer that is embedding-like.
-pub trait EmbeddingLayerLike: Module + Debug {
+pub trait EmbeddingLayerLike: Module + Debug + Saveable {
     fn embeddings(&self) -> &Tensor;
     fn hidden_size(&self) -> usize;
 }
 
+impl Saveable for Embedding {
+    fn get_tensors(&self, _accum: &mut HashMap<String, Tensor>) {
+        unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers.");
+    }
+}
+
 impl EmbeddingLayerLike for Embedding {
     fn embeddings(&self) -> &Tensor {
         self.embeddings()
diff --git a/candle-lora/src/loraconv1d.rs b/candle-lora/src/loraconv1d.rs
index 9bd6332..ea7e579 100644
--- a/candle-lora/src/loraconv1d.rs
+++ b/candle-lora/src/loraconv1d.rs
@@ -1,4 +1,4 @@
-use std::ops::Mul;
+use std::{collections::HashMap, ops::Mul};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarBuilder};
@@ -7,6 +7,7 @@ use trc::Trc;
 
 use crate::{
     frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
+    Saveable,
 };
 
 #[derive(Debug, Clone)]
@@ -17,6 +18,8 @@ pub struct LoraConv1d {
     scale: Option<f64>,
     dropout: Option<Trc<Dropout>>,
     merged: bool,
+    prefix: String,
+    id: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -73,6 +76,8 @@ impl LoraConv1d {
             },
             dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
             merged: false,
+            prefix: vb.prefix(),
+            id,
         })
     }
 }
@@ -155,6 +160,19 @@ impl Module for LoraConv1d {
     }
 }
 
+impl Saveable for LoraConv1d {
+    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
+        accum.insert(
+            self.prefix.clone() + &format!("a{}.weight", self.id),
+            self.a.clone(),
+        );
+        accum.insert(
+            self.prefix.clone() + &format!("b{}.weight", self.id),
+            self.b.clone(),
+        );
+    }
+}
+
 impl Conv1dLayerLike for LoraConv1d {
     fn config(&self) -> &Conv1dConfig {
         self.old.config()
diff --git a/candle-lora/src/loraconv2d.rs b/candle-lora/src/loraconv2d.rs
index 09e682b..0eb699e 100644
--- a/candle-lora/src/loraconv2d.rs
+++ b/candle-lora/src/loraconv2d.rs
@@ -1,4 +1,4 @@
-use std::ops::Mul;
+use std::{collections::HashMap, ops::Mul};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarBuilder};
@@ -7,6 +7,7 @@ use trc::Trc;
 
 use crate::{
     frozenconv::FrozenConv2d, Conv2dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
+    Saveable,
 };
 
 #[derive(Debug, Clone)]
@@ -17,6 +18,8 @@ pub struct LoraConv2d {
     scale: Option<f64>,
     dropout: Option<Trc<Dropout>>,
     merged: bool,
+    prefix: String,
+    id: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -85,6 +88,8 @@ impl LoraConv2d {
             },
             dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
             merged: false,
+            prefix: vb.prefix(),
+            id,
         })
     }
 }
@@ -189,6 +194,19 @@ impl Module for LoraConv2d {
     }
 }
 
+impl Saveable for LoraConv2d {
+    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
+        accum.insert(
+            self.prefix.clone() + &format!("a{}.weight", self.id),
+            self.a_conv.weight().clone(),
+        );
+        accum.insert(
+            self.prefix.clone() + &format!("b{}.weight", self.id),
+            self.b_conv.weight().clone(),
+        );
+    }
+}
+
 impl Conv2dLayerLike for LoraConv2d {
     fn config(&self) -> &Conv2dConfig {
         self.old.config()
diff --git a/candle-lora/src/loraembed.rs b/candle-lora/src/loraembed.rs
index 1350bad..bff4446 100644
--- a/candle-lora/src/loraembed.rs
+++ b/candle-lora/src/loraembed.rs
@@ -1,4 +1,4 @@
-use std::ops::Mul;
+use std::{collections::HashMap, ops::Mul};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Embedding, Init, VarBuilder};
@@ -7,7 +7,7 @@ use trc::Trc;
 
 use crate::{
     frozenembed::FrozenEmbedding, EmbeddingLayerLike, LoraConfig, Merge, MergeError,
-    MergeErrorOrError,
+    MergeErrorOrError, Saveable,
 };
 
 #[derive(Debug, Clone)]
@@ -18,6 +18,8 @@ pub struct LoraEmbedding {
     b: Tensor,
     scale: Option<f64>,
     merged: bool,
+    prefix: String,
+    id: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -73,6 +75,8 @@ impl LoraEmbedding {
                 None
             },
             merged: false,
+            prefix: vb.prefix(),
+            id,
         })
     }
 }
@@ -135,6 +139,19 @@ impl Module for LoraEmbedding {
     }
 }
 
+impl Saveable for LoraEmbedding {
+    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
+        accum.insert(
+            self.prefix.clone() + &format!("a{}.weight", self.id),
+            self.a.clone(),
+        );
+        accum.insert(
+            self.prefix.clone() + &format!("b{}.weight", self.id),
+            self.b.clone(),
+        );
+    }
+}
+
 impl EmbeddingLayerLike for LoraEmbedding {
     fn embeddings(&self) -> &Tensor {
         self.old.embeddings()
diff --git a/candle-lora/src/loralinear.rs b/candle-lora/src/loralinear.rs
index 2a1e05f..0605b55 100644
--- a/candle-lora/src/loralinear.rs
+++ b/candle-lora/src/loralinear.rs
@@ -1,4 +1,4 @@
-use std::ops::Mul;
+use std::{collections::HashMap, ops::Mul};
 
 use candle_core::{Module, Result, Shape, Tensor};
 use candle_nn::{init, Dropout, Linear, VarBuilder};
@@ -7,6 +7,7 @@ use trc::Trc;
 
 use crate::{
     frozenlinear::FrozenLinear, LinearLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
+    Saveable,
 };
 
 #[derive(Debug, Clone)]
@@ -17,6 +18,8 @@ pub struct LoraLinear {
     scale: Option<f64>,
     dropout: Option<Trc<Dropout>>,
     merged: bool,
+    prefix: String,
+    id: usize,
 }
 
 #[derive(Clone, Debug)]
@@ -65,6 +68,8 @@ impl LoraLinear {
             },
             dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
             merged: false,
+            prefix: vb.prefix(),
+            id,
         })
     }
 }
@@ -137,6 +142,19 @@ impl Module for LoraLinear {
     }
 }
 
+impl Saveable for LoraLinear {
+    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
+        accum.insert(
+            self.prefix.clone() + &format!("a{}.weight", self.id),
+            self.ff_a.weight().clone(),
+        );
+        accum.insert(
+            self.prefix.clone() + &format!("b{}.weight", self.id),
+            self.ff_b.weight().clone(),
+        );
+    }
+}
+
 impl LinearLayerLike for LoraLinear {
     fn bias(&self) -> Option<&Tensor> {
         self.old.bias()
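For reference, a hypothetical layer outside `candle-lora` can satisfy the new `Saveable` bound by mirroring the `{prefix}a{id}.weight` / `{prefix}b{id}.weight` naming used by the LoRA layers above. `MyAdapter` and its fields are illustrative only, and the accumulator type assumes the `HashMap<String, Tensor>` signature shown in `candle-lora/src/lib.rs`:

```rust
use std::collections::HashMap;

use candle_core::Tensor;
use candle_lora::Saveable; // the trait added in candle-lora/src/lib.rs above

/// Hypothetical adapter layer holding a pair of low-rank weight tensors.
struct MyAdapter {
    prefix: String, // typically `vb.prefix()`, as in the LoRA layers
    id: usize,
    a: Tensor,
    b: Tensor,
}

impl Saveable for MyAdapter {
    fn get_tensors(&self, accum: &mut HashMap<String, Tensor>) {
        // Same naming convention as LoraLinear, LoraConv1d, LoraConv2d and LoraEmbedding.
        accum.insert(
            self.prefix.clone() + &format!("a{}.weight", self.id),
            self.a.clone(),
        );
        accum.insert(
            self.prefix.clone() + &format!("b{}.weight", self.id),
            self.b.clone(),
        );
    }
}
```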