From dfe39c06319d1e50ec8cd53892ce8d78779c3885 Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Sat, 9 Sep 2023 20:40:30 +0000 Subject: [PATCH] [derive] Simplify code, remove obsolete features Clean up the implementation, especially in `fn impl_block`. Make the following notable changes: - Previously, `syn` didn't support parsing macro invocations in const generics without the `full` feature enabled. To avoid the compile-time overhead of that feature, we worked around it by constructing AST nodes manually. `syn` has since added support for this without requiring the `full` feature, so we make use of it. - We used to need to split types into those that transatively depended upon type generics (like `[T; 2]`) and those that didn't (like `[u8; 2]`). We made a change in #119 that made this distinction irrelevant, but we never removed the code to perform the split. In this commit, we remove that code. That code was the only reason we needed to enable `syn`'s `visit` feature, so we are also able to remove that feature dependency. --- zerocopy-derive/Cargo.toml | 2 +- zerocopy-derive/src/ext.rs | 51 +++----- zerocopy-derive/src/lib.rs | 248 ++++++++++++------------------------- 3 files changed, 101 insertions(+), 200 deletions(-) diff --git a/zerocopy-derive/Cargo.toml b/zerocopy-derive/Cargo.toml index 3d19d6ce16d..4eae69f393c 100644 --- a/zerocopy-derive/Cargo.toml +++ b/zerocopy-derive/Cargo.toml @@ -20,7 +20,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.1" quote = "1.0.10" -syn = { version = "2", features = ["visit"] } +syn = "2.0.31" [dev-dependencies] rustversion = "1.0" diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index 45b592ee693..ff8a3d6596a 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -2,34 +2,39 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Fields, Type}; +use syn::{Data, DataEnum, DataStruct, DataUnion, Type}; pub trait DataExt { - fn nested_types(&self) -> Vec<&Type>; + /// Extract the types of all fields. For enums, extract the types of fields + /// from each variant. + fn field_types(&self) -> Vec<&Type>; } impl DataExt for Data { - fn nested_types(&self) -> Vec<&Type> { + fn field_types(&self) -> Vec<&Type> { match self { - Data::Struct(strc) => strc.nested_types(), - Data::Enum(enm) => enm.nested_types(), - Data::Union(un) => un.nested_types(), + Data::Struct(strc) => strc.field_types(), + Data::Enum(enm) => enm.field_types(), + Data::Union(un) => un.field_types(), } } } impl DataExt for DataStruct { - fn nested_types(&self) -> Vec<&Type> { - fields_to_types(&self.fields) + fn field_types(&self) -> Vec<&Type> { + self.fields.iter().map(|f| &f.ty).collect() } } impl DataExt for DataEnum { - fn nested_types(&self) -> Vec<&Type> { - self.variants.iter().map(|var| fields_to_types(&var.fields)).fold(Vec::new(), |mut a, b| { - a.extend(b); - a - }) + fn field_types(&self) -> Vec<&Type> { + self.variants.iter().flat_map(|var| &var.fields).map(|f| &f.ty).collect() + } +} + +impl DataExt for DataUnion { + fn field_types(&self) -> Vec<&Type> { + self.fields.named.iter().map(|f| &f.ty).collect() } } @@ -39,24 +44,6 @@ pub trait EnumExt { impl EnumExt for DataEnum { fn is_c_like(&self) -> bool { - self.nested_types().is_empty() + self.field_types().is_empty() } } - -impl DataExt for DataUnion { - fn nested_types(&self) -> Vec<&Type> { - field_iter_to_types(&self.fields.named) - } -} - -fn fields_to_types(fields: &Fields) -> Vec<&Type> { - match fields { - Fields::Named(named) => field_iter_to_types(&named.named), - Fields::Unnamed(unnamed) => field_iter_to_types(&unnamed.unnamed), - Fields::Unit => Vec::new(), - } -} - -fn field_iter_to_types<'a, I: IntoIterator>(fields: I) -> Vec<&'a Type> { - fields.into_iter().map(|f| &f.ty).collect() -} diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 8793665cb14..6b9e3a40f24 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -30,10 +30,9 @@ mod repr; use { proc_macro2::Span, quote::quote, - syn::visit::{self, Visit}, syn::{ - parse_quote, punctuated::Punctuated, token::Comma, Data, DataEnum, DataStruct, DataUnion, - DeriveInput, Error, Expr, ExprLit, GenericParam, Ident, Lifetime, Lit, Type, TypePath, + parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, + GenericParam, Ident, Lit, }, }; @@ -122,7 +121,7 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // - all fields are `FromZeroes` fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, strct, "FromZeroes", true, None) } // An enum is `FromZeroes` if: @@ -156,21 +155,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, enm, "FromZeroes", true, None) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, PaddingCheck::None) + impl_block(ast, unn, "FromZeroes", true, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, PaddingCheck::None) + impl_block(ast, strct, "FromBytes", true, None) } // An enum is `FromBytes` if: @@ -213,7 +212,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, PaddingCheck::None) + impl_block(ast, enm, "FromBytes", true, None) } #[rustfmt::skip] @@ -244,7 +243,7 @@ const ENUM_FROM_BYTES_CFG: Config = { // - all fields are `FromBytes` fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, PaddingCheck::None) + impl_block(ast, unn, "FromBytes", true, None) } // A struct is `AsBytes` if: @@ -277,8 +276,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: // - repr(packed): Any inter-field padding bytes are removed, meaning that // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). - let padding_check = - if is_transparent || is_packed { PaddingCheck::None } else { PaddingCheck::Struct }; + let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; impl_block(ast, strct, "AsBytes", true, padding_check) } @@ -302,7 +300,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, PaddingCheck::None) + impl_block(ast, enm, "AsBytes", false, None) } #[rustfmt::skip] @@ -344,7 +342,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, PaddingCheck::Union) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union)) } // A struct is `Unaligned` if: @@ -357,7 +355,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2 let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, PaddingCheck::None) + impl_block(ast, strct, "Unaligned", require_trait_bound, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -388,7 +386,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, PaddingCheck::None) + impl_block(ast, enm, "Unaligned", true, None) } #[rustfmt::skip] @@ -426,26 +424,37 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::To let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, PaddingCheck::None) + impl_block(ast, unn, "Unaligned", require_trait_bound, None) } // This enum describes what kind of padding check needs to be generated for the // associated impl. enum PaddingCheck { - // No additional padding check is required. - None, // Check that the sum of the fields' sizes exactly equals the struct's size. Struct, // Check that the size of each field exactly equals the union's size. Union, } +impl PaddingCheck { + /// Returns the ident of the macro to call in order to validate that a type + /// passes the padding check encoded by `PaddingCheck`. + fn validator_macro_ident(&self) -> Ident { + let s = match self { + PaddingCheck::Struct => "struct_has_padding", + PaddingCheck::Union => "union_has_padding", + }; + + Ident::new(s, Span::call_site()) + } +} + fn impl_block( input: &DeriveInput, data: &D, trait_name: &str, require_trait_bound: bool, - padding_check: PaddingCheck, + padding_check: Option, ) -> proc_macro2::TokenStream { // In this documentation, we will refer to this hypothetical struct: // @@ -461,22 +470,10 @@ fn impl_block( // c: I::Item, // } // - // First, we extract the field types, which in this case are `u8`, `T`, and - // `I::Item`. We use the names of the type parameters to split the field - // types into two sets - a set of types which are based on the type - // parameters, and a set of types which are not. First, we re-use the - // existing parameters and where clauses, generating an `impl` block like: - // - // impl FromBytes for Foo - // where - // T: Copy, - // I: Clone, - // I::Item: Clone, - // { - // } - // - // Then, we use the list of types which are based on the type parameters to - // generate new entries in the `where` clause: + // We extract the field types, which in this case are `u8`, `T`, and + // `I::Item`. We re-use the existing parameters and where clauses. If + // `require_trait_bound == true` (as it is for `FromBytes), we add where + // bounds for each field's type: // // impl FromBytes for Foo // where @@ -488,18 +485,6 @@ fn impl_block( // { // } // - // Finally, we use a different technique to generate the bounds for the - // types which are not based on type parameters: - // - // - // fn only_derive_is_allowed_to_implement_this_trait() where Self: Sized { - // struct ImplementsFromBytes(PhantomData); - // let _: ImplementsFromBytes; - // } - // - // It would be easier to put all types in the where clause, but that won't - // work until the trivial_bounds feature is stabilized (#48214). - // // NOTE: It is standard practice to only emit bounds for the type parameters // themselves, not for field types based on those parameters (e.g., `T` vs // `T::Foo`). For a discussion of why this is standard practice, see @@ -521,7 +506,6 @@ fn impl_block( // b: PhantomData<&'b u8>, // } // - // // error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned` // --> src/main.rs:6:10 // | @@ -530,67 +514,37 @@ fn impl_block( // | // = note: required by `zerocopy::Unaligned` - // A visitor which is used to walk a field's type and determine whether any - // of its definition is based on the type or lifetime parameters on a type. - struct FromTypeParamVisit<'a, 'b>(&'a Punctuated, &'b mut bool); - - impl<'a, 'b> Visit<'a> for FromTypeParamVisit<'a, 'b> { - fn visit_lifetime(&mut self, i: &'a Lifetime) { - visit::visit_lifetime(self, i); - if self.0.iter().any(|param| { - if let GenericParam::Lifetime(param) = param { - param.lifetime.ident == i.ident - } else { - false - } - }) { - *self.1 = true; - } - } - - fn visit_type_path(&mut self, i: &'a TypePath) { - visit::visit_type_path(self, i); - if self.0.iter().any(|param| { - if let GenericParam::Type(param) = param { - i.path.segments.first().unwrap().ident == param.ident - } else { - false - } - }) { - *self.1 = true; - } - } - } - - // Whether this type is based on one of the type parameters. E.g., given the - // type parameters ``, `T`, `T::Foo`, and `(T::Foo, String)` are all - // based on the type parameters, while `String` and `(String, Box<()>)` are - // not. - let is_from_type_param = |ty: &Type| { - let mut ret = false; - FromTypeParamVisit(&input.generics.params, &mut ret).visit_type(ty); - ret - }; - + let type_ident = &input.ident; let trait_ident = Ident::new(trait_name, Span::call_site()); + let field_types = data.field_types(); + + let field_type_bounds = require_trait_bound + .then(|| field_types.iter().map(|ty| parse_quote!(#ty: zerocopy::#trait_ident))) + .into_iter() + .flatten() + .collect::>(); + + // Don't bother emitting a padding check if there are no fields. + #[allow(unstable_name_collisions)] // See `BoolExt` below + let padding_check_bound = padding_check.and_then(|check| (!field_types.is_empty()).then_some(check)).map(|check| { + let fields = field_types.iter(); + let validator_macro = check.validator_macro_ident(); + parse_quote!( + zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::#validator_macro!(#type_ident, #(#fields),*)}>: + zerocopy::derive_util::ShouldBe + ) + }); - let field_types = data.nested_types(); - let type_param_field_types = field_types.iter().filter(|ty| is_from_type_param(ty)); - let non_type_param_field_types = field_types.iter().filter(|ty| !is_from_type_param(ty)); - - // Add a new set of where clause predicates of the form `T: Trait` for each - // of the types of the struct's fields (but only the ones whose types are - // based on one of the type parameters). - let mut generics = input.generics.clone(); - let where_clause = generics.make_where_clause(); - if require_trait_bound { - for ty in type_param_field_types { - let bound = parse_quote!(#ty: zerocopy::#trait_ident); - where_clause.predicates.push(bound); - } - } + let bounds = input + .generics + .where_clause + .as_ref() + .map(|where_clause| where_clause.predicates.iter()) + .into_iter() + .flatten() + .chain(field_type_bounds.iter()) + .chain(padding_check_bound.iter()); - let type_ident = &input.ident; // The parameters with trait bounds, but without type defaults. let params = input.generics.params.clone().into_iter().map(|mut param| { match &mut param { @@ -610,70 +564,13 @@ fn impl_block( GenericParam::Const(cnst) => quote!(#cnst), }); - if require_trait_bound { - for ty in non_type_param_field_types { - where_clause.predicates.push(parse_quote!(#ty: zerocopy::#trait_ident)); - } - } - - match (field_types.is_empty(), padding_check) { - (true, _) | (false, PaddingCheck::None) => (), - (false, PaddingCheck::Struct) => { - let fields = field_types.iter(); - // `parse_quote!` doesn't parse macro invocations in const generics - // properly without enabling syn's `full` feature, so the type has - // to be manually constructed as `syn::Type::Verbatim`. - // - // This where clause is equivalent to adding: - // ``` - // HasPadding: ShouldBe - // ``` - // with fully-qualified paths. - where_clause.predicates.push(syn::WherePredicate::Type(syn::PredicateType { - lifetimes: None, - bounded_ty: syn::Type::Verbatim(quote!(zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::struct_has_padding!(#type_ident, #(#fields),*)}>)), - colon_token: syn::Token![:](Span::mixed_site()), - bounds: parse_quote!(zerocopy::derive_util::ShouldBe), - })); - } - (false, PaddingCheck::Union) => { - let fields = field_types.iter(); - // `parse_quote!` doesn't parse macro invocations in const generics - // properly without enabling syn's `full` feature, so the type has - // to be manually constructed as `syn::Type::Verbatim`. - // - // This where clause is equivalent to adding: - // ``` - // HasPadding: ShouldBe - // ``` - // with fully-qualified paths. - where_clause.predicates.push(syn::WherePredicate::Type(syn::PredicateType { - lifetimes: None, - bounded_ty: syn::Type::Verbatim(quote!(zerocopy::derive_util::HasPadding<#type_ident, {zerocopy::union_has_padding!(#type_ident, #(#fields),*)}>)), - colon_token: syn::Token![:](Span::mixed_site()), - bounds: parse_quote!(zerocopy::derive_util::ShouldBe), - })); - } - } - - // We use a constant to force the compiler to emit an error when a concrete - // type does not satisfy the where clauses on its impl. - let use_concrete = if input.generics.params.is_empty() { - Some(quote! { - const _: () = { - fn must_implement_trait() {} - let _ = must_implement_trait::<#type_ident>; - }; - }) - } else { - None - }; - quote! { - unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > #where_clause { + unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > + where + #(#bounds,)* + { fn only_derive_is_allowed_to_implement_this_trait() {} } - #use_concrete } } @@ -681,6 +578,23 @@ fn print_all_errors(errors: Vec) -> proc_macro2::TokenStream { errors.iter().map(Error::to_compile_error).collect() } +// A polyfill for `Option::then_some`, which was added after our MSRV. +// +// TODO(#67): Remove this once our MSRV is >= 1.62. +trait BoolExt { + fn then_some(self, t: T) -> Option; +} + +impl BoolExt for bool { + fn then_some(self, t: T) -> Option { + if self { + Some(t) + } else { + None + } + } +} + #[cfg(test)] mod tests { use super::*;