Skip to content

Commit

Permalink
Apply send, sync bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 18, 2024
1 parent 86190da commit 0584189
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 108 deletions.
196 changes: 95 additions & 101 deletions candle-lora-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,127 +9,121 @@ use syn::{
pub fn replace_layer_fields(_args: TokenStream1, input: TokenStream1) -> TokenStream1 {
let mut ast = parse_macro_input!(input as DeriveInput);
match &mut ast.data {
Data::Struct(ref mut struct_data) => {
match &mut struct_data.fields {
Fields::Named(fields) => {
for field in fields.named.iter_mut() {
let mut f = None;
let ident = field.ident.clone().unwrap();
let ty = field.ty.clone();
if let Type::Path(path) = ty {
if path.path.segments.len() == 1 {
match path
.path
.segments
.first()
.unwrap()
.ident
.to_string()
.as_str()
{
"Linear" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn LinearLayerLike>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn LinearLayerLike>)).unwrap());
}
Data::Struct(ref mut struct_data) => match &mut struct_data.fields {
Fields::Named(fields) => {
for field in fields.named.iter_mut() {
let mut f = None;
let ident = field.ident.clone().unwrap();
let ty = field.ty.clone();
if let Type::Path(path) = ty {
if path.path.segments.len() == 1 {
match path
.path
.segments
.first()
.unwrap()
.ident
.to_string()
.as_str()
{
"Linear" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn LinearLayerLike + Send + Sync>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn LinearLayerLike + Send + Sync>)).unwrap());
}
"Conv1d" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn Conv1dLayerLike>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn Conv1dLayerLike>)).unwrap());
}
}
"Conv1d" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn Conv1dLayerLike + Send + Sync + Send + Sync>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn Conv1dLayerLike + Send + Sync>)).unwrap());
}
"Conv2d" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn Conv2dLayerLike>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn Conv2dLayerLike>)).unwrap());
}
}
"Conv2d" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn Conv2dLayerLike + Send + Sync>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn Conv2dLayerLike + Send + Sync>)).unwrap());
}
"Embedding" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn EmbeddingLayerLike>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn EmbeddingLayerLike>)).unwrap());
}
}
"Embedding" => {
if let Visibility::Public(_) = field.vis {
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc<dyn EmbeddingLayerLike + Send + Sync>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc<dyn EmbeddingLayerLike + Send + Sync>)).unwrap());
}
"Option" => {
if let PathArguments::AngleBracketed(bracketed) =
&path.path.segments.first().unwrap().arguments
{
if bracketed.args.len() == 1 {
if let GenericArgument::Type(Type::Path(tp)) =
bracketed.args.first().unwrap()
{
if tp.path.segments.len() == 1 {
match tp
.path
.segments
.first()
.unwrap()
.ident
.to_string()
.as_str()
{
"Linear" => {
if let Visibility::Public(_) =
field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn LinearLayerLike>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn LinearLayerLike>>)).unwrap());
}
}
"Option" => {
if let PathArguments::AngleBracketed(bracketed) =
&path.path.segments.first().unwrap().arguments
{
if bracketed.args.len() == 1 {
if let GenericArgument::Type(Type::Path(tp)) =
bracketed.args.first().unwrap()
{
if tp.path.segments.len() == 1 {
match tp
.path
.segments
.first()
.unwrap()
.ident
.to_string()
.as_str()
{
"Linear" => {
if let Visibility::Public(_) = field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn LinearLayerLike + Send + Sync>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn LinearLayerLike + Send + Sync>>)).unwrap());
}
"Conv1d" => {
if let Visibility::Public(_) =
field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn Conv1dLayerLike>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn Conv1dLayerLike>>)).unwrap());
}
}
"Conv1d" => {
if let Visibility::Public(_) = field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn Conv1dLayerLike + Send + Sync>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn Conv1dLayerLike + Send + Sync>>)).unwrap());
}
"Conv2d" => {
if let Visibility::Public(_) =
field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn Conv2dLayerLike>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn Conv2dLayerLike>>)).unwrap());
}
}
"Conv2d" => {
if let Visibility::Public(_) = field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn Conv2dLayerLike + Send + Sync>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn Conv2dLayerLike + Send + Sync>>)).unwrap());
}
"Embedding" => {
if let Visibility::Public(_) =
field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn EmbeddingLayerLike>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn EmbeddingLayerLike>>)).unwrap());
}
}
"Embedding" => {
if let Visibility::Public(_) = field.vis
{
f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option<Arc<dyn EmbeddingLayerLike + Send + Sync>>)).unwrap());
} else {
f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option<Arc<dyn EmbeddingLayerLike + Send + Sync>>)).unwrap());
}
_ => {}
}
_ => {}
}
}
}
}
}
_ => {}
}
_ => {}
}
}
if let Some(f) = f {
*field = f;
}
}
}
_ => {
panic!("Named fields are required.")
if let Some(f) = f {
*field = f;
}
}
}
}
_ => {
panic!("Named fields are required.")
}
},
_ => {
panic!("Cannot swap fields of non struct!");
}
Expand Down
2 changes: 1 addition & 1 deletion candle-lora-transformers/src/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ struct BertEmbedding {
}

impl Deref for BertEmbedding {
type Target = Arc<dyn EmbeddingLayerLike>;
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/bigcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct CustomLinear {
}

impl Deref for CustomLinear {
type Target = Arc<dyn LinearLayerLike>;
type Target = Arc<dyn LinearLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand All @@ -30,7 +30,7 @@ struct CustomEmbedding {
}

impl Deref for CustomEmbedding {
type Target = Arc<dyn EmbeddingLayerLike>;
type Target = Arc<dyn EmbeddingLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/dinov2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct DinoLinear {
}

impl Deref for DinoLinear {
type Target = Arc<dyn LinearLayerLike>;
type Target = Arc<dyn LinearLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down Expand Up @@ -290,7 +290,7 @@ struct DinoConv2d {
}

impl Deref for DinoConv2d {
type Target = Arc<dyn Conv2dLayerLike>;
type Target = Arc<dyn Conv2dLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.inner
Expand Down
4 changes: 2 additions & 2 deletions candle-lora-transformers/src/falcon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,15 @@ struct AttentionDense {
}

impl Deref for AttentionQKV {
type Target = Arc<dyn LinearLayerLike>;
type Target = Arc<dyn LinearLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.query_key_value
}
}

impl Deref for AttentionDense {
type Target = Arc<dyn LinearLayerLike>;
type Target = Arc<dyn LinearLayerLike + Send + Sync>;

fn deref(&self) -> &Self::Target {
&self.dense
Expand Down

0 comments on commit 0584189

Please sign in to comment.