From 239c21ddee54b1904e23d512f7850b9c3debb3cf Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 14 Sep 2023 10:54:34 -0400 Subject: [PATCH] Fixing clippy errors --- dfdx-nn-core/src/lib.rs | 14 +++--- dfdx-nn-core/src/tuples.rs | 4 +- dfdx-nn-core/src/vecs.rs | 4 +- dfdx-nn-derives/src/lib.rs | 72 ++++++++---------------------- dfdx-nn/src/layers/add_into.rs | 2 + dfdx-nn/src/layers/conv1d.rs | 1 + dfdx-nn/src/layers/conv2d.rs | 1 + dfdx-nn/src/layers/conv_trans2d.rs | 1 + 8 files changed, 34 insertions(+), 65 deletions(-) diff --git a/dfdx-nn-core/src/lib.rs b/dfdx-nn-core/src/lib.rs index beb3bccc..04b4cb61 100644 --- a/dfdx-nn-core/src/lib.rs +++ b/dfdx-nn-core/src/lib.rs @@ -153,7 +153,7 @@ pub trait SaveSafeTensors { let data = tensors.iter().map(|(k, dtype, shape, data)| { ( k.clone(), - safetensors::tensor::TensorView::new(dtype.clone(), shape.clone(), data).unwrap(), + safetensors::tensor::TensorView::new(*dtype, shape.clone(), data).unwrap(), ) }); @@ -178,18 +178,18 @@ pub trait LoadSafeTensors { self.read_safetensors("", &tensors) } - fn read_safetensors<'a>( + fn read_safetensors( &mut self, location: &str, - tensors: &safetensors::SafeTensors<'a>, + tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError>; } impl, T> LoadSafeTensors for Tensor { - fn read_safetensors<'a>( + fn read_safetensors( &mut self, location: &str, - tensors: &safetensors::SafeTensors<'a>, + tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { self.load_safetensor(tensors, location) } @@ -230,10 +230,10 @@ macro_rules! unit_safetensors { } impl LoadSafeTensors for $Ty { - fn read_safetensors<'a>( + fn read_safetensors( &mut self, location: &str, - tensors: &safetensors::SafeTensors<'a>, + tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { #[allow(unused_imports)] use dfdx::dtypes::FromLeBytes; diff --git a/dfdx-nn-core/src/tuples.rs b/dfdx-nn-core/src/tuples.rs index f7932438..ad446bd6 100644 --- a/dfdx-nn-core/src/tuples.rs +++ b/dfdx-nn-core/src/tuples.rs @@ -22,10 +22,10 @@ macro_rules! tuple_impls { } impl<$($name: crate::LoadSafeTensors, )+> crate::LoadSafeTensors for ($($name,)+) { - fn read_safetensors<'a>( + fn read_safetensors( &mut self, location: &str, - tensors: &safetensors::SafeTensors<'a>, + tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { $(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+ Ok(()) diff --git a/dfdx-nn-core/src/vecs.rs b/dfdx-nn-core/src/vecs.rs index 1455b5f1..cbe7fc74 100644 --- a/dfdx-nn-core/src/vecs.rs +++ b/dfdx-nn-core/src/vecs.rs @@ -54,10 +54,10 @@ impl crate::SaveSafeTensors for Vec { } impl crate::LoadSafeTensors for Vec { - fn read_safetensors<'a>( + fn read_safetensors( &mut self, location: &str, - tensors: &safetensors::SafeTensors<'a>, + tensors: &safetensors::SafeTensors, ) -> Result<(), safetensors::SafeTensorError> { for (i, t) in self.iter_mut().enumerate() { t.read_safetensors(&format!("{location}{i}."), tensors)?; diff --git a/dfdx-nn-derives/src/lib.rs b/dfdx-nn-derives/src/lib.rs index 89fd272e..78bec0d4 100644 --- a/dfdx-nn-derives/src/lib.rs +++ b/dfdx-nn-derives/src/lib.rs @@ -544,29 +544,17 @@ pub fn reset_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let name = input.ident; let mut custom_generics = input.generics.clone(); - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"), + ) { custom_generics .params .push(parse_quote!(Elem: dfdx::prelude::Dtype)); } - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"), + ) { custom_generics .params .push(parse_quote!(Dev: dfdx::prelude::Device)); @@ -631,29 +619,17 @@ pub fn update_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream let struct_name = input.ident; let mut custom_generics = input.generics.clone(); - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"), + ) { custom_generics .params .push(parse_quote!(Elem: dfdx::prelude::Dtype)); } - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"), + ) { custom_generics .params .push(parse_quote!(Dev: dfdx::prelude::Device)); @@ -727,29 +703,17 @@ pub fn zero_grads(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let name = input.ident; let mut custom_generics = input.generics.clone(); - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Elem" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Elem"), + ) { custom_generics .params .push(parse_quote!(Elem: dfdx::prelude::Dtype)); } - if custom_generics - .params - .iter() - .position(|param| match param { - syn::GenericParam::Type(type_param) if type_param.ident == "Dev" => true, - _ => false, - }) - .is_none() - { + if !custom_generics.params.iter().any( + |param| matches!(param, syn::GenericParam::Type(type_param) if type_param.ident == "Dev"), + ) { custom_generics .params .push(parse_quote!(Dev: dfdx::prelude::Device)); diff --git a/dfdx-nn/src/layers/add_into.rs b/dfdx-nn/src/layers/add_into.rs index e830a14a..71afc2ba 100644 --- a/dfdx-nn/src/layers/add_into.rs +++ b/dfdx-nn/src/layers/add_into.rs @@ -59,6 +59,7 @@ macro_rules! add_into_impls { type Output = Out; type Error = A::Error; + #[allow(clippy::needless_question_mark)] fn try_forward(&self, x: (Ai, $($Inp, )+)) -> Result { let (a, $($ModVar, )+) = &self.0; let (a_i, $($InpVar, )+) = x; @@ -66,6 +67,7 @@ macro_rules! add_into_impls { $(let $InpVar = $ModVar.try_forward($InpVar)?;)+ Ok(sum!(a_i, $($InpVar),*)) } + #[allow(clippy::needless_question_mark)] fn try_forward_mut(&mut self, x: (Ai, $($Inp, )+)) -> Result { let (a, $($ModVar, )+) = &mut self.0; let (a_i, $($InpVar, )+) = x; diff --git a/dfdx-nn/src/layers/conv1d.rs b/dfdx-nn/src/layers/conv1d.rs index 65d6aea7..cf7998ba 100644 --- a/dfdx-nn/src/layers/conv1d.rs +++ b/dfdx-nn/src/layers/conv1d.rs @@ -97,6 +97,7 @@ where { #[param] #[serialize] + #[allow(clippy::type_complexity)] pub weight: Tensor< ( OutChan, diff --git a/dfdx-nn/src/layers/conv2d.rs b/dfdx-nn/src/layers/conv2d.rs index 32e4bfc0..b7e98294 100644 --- a/dfdx-nn/src/layers/conv2d.rs +++ b/dfdx-nn/src/layers/conv2d.rs @@ -118,6 +118,7 @@ where { #[param] #[serialize] + #[allow(clippy::type_complexity)] pub weight: Tensor< ( OutChan, diff --git a/dfdx-nn/src/layers/conv_trans2d.rs b/dfdx-nn/src/layers/conv_trans2d.rs index 12eac0ad..0b23614f 100644 --- a/dfdx-nn/src/layers/conv_trans2d.rs +++ b/dfdx-nn/src/layers/conv_trans2d.rs @@ -96,6 +96,7 @@ where { #[param] #[serialize] + #[allow(clippy::type_complexity)] pub weight: Tensor< ( InChan,