Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

derive(Zeroable) on fieldful enums and repr(C) enums #257

Merged
merged 9 commits into from
Sep 24, 2024
74 changes: 60 additions & 14 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,26 @@ pub fn derive_anybitpattern(
proc_macro::TokenStream::from(expanded)
}

/// Derive the `Zeroable` trait for a struct
/// Derive the `Zeroable` trait for a type.
///
/// The macro ensures that the struct follows all the the safety requirements
/// The macro ensures that the type follows all the the safety requirements
/// for the `Zeroable` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
/// The following constraints need to be satisfied for the macro to succeed on a
/// struct:
///
/// - All fields in the struct must implement `Zeroable`
///
/// The following constraints need to be satisfied for the macro to succeed on
/// an enum:
///
/// - All fields in the struct must to implement `Zeroable`
/// - The enum has an explicit `#[repr(Int)]`, `#[repr(C)]`, or `#[repr(C,
/// Int)]`.
/// - The enum has a variant with discriminant 0 (explicitly or implicitly).
/// - All fields in the variant with discriminant 0 (if any) must implement
/// `Zeroable`
///
/// The macro always succeeds on unions.
///
/// ## Example
///
Expand All @@ -134,6 +146,23 @@ pub fn derive_anybitpattern(
/// b: u16,
/// }
/// ```
/// ```rust
/// # use bytemuck_derive::{Zeroable};
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(i32)]
/// enum Values {
/// A = 0,
/// B = 1,
/// C = 2,
/// }
/// #[derive(Clone, Zeroable)]
/// #[repr(C)]
/// enum Implicit {
/// A(bool, u8, char),
/// B(String),
/// C(std::num::NonZeroU8),
/// }
/// ```
///
/// # Custom bounds
///
Expand All @@ -157,6 +186,18 @@ pub fn derive_anybitpattern(
///
/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
/// ```
/// ```rust
/// # use bytemuck::{Zeroable};
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(u8)]
/// #[zeroable(bound = "")]
/// enum MyOption<T> {
/// None,
/// Some(T),
/// }
///
/// assert!(matches!(MyOption::<std::num::NonZeroU8>::zeroed(), MyOption::None));
/// ```
///
/// ```rust,compile_fail
/// # use bytemuck::Zeroable;
Expand Down Expand Up @@ -407,7 +448,8 @@ pub fn derive_byte_eq(
let input = parse_macro_input!(input as DeriveInput);
let crate_name = bytemuck_crate_name(&input);
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
Expand Down Expand Up @@ -460,7 +502,8 @@ pub fn derive_byte_hash(
let input = parse_macro_input!(input as DeriveInput);
let crate_name = bytemuck_crate_name(&input);
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
Expand Down Expand Up @@ -569,26 +612,29 @@ fn derive_marker_trait_inner<Trait: Derivable>(
.flatten()
.collect::<Vec<syn::WherePredicate>>();

let predicates = &mut input.generics.make_where_clause().predicates;

predicates.extend(explicit_bounds);

let fields = match &input.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
syn::Data::Union(_) => {
let fields = match (Trait::perfect_derive_fields(&input), &input.data) {
(Some(fields), _) => fields,
(None, syn::Data::Struct(syn::DataStruct { fields, .. })) => {
fields.clone()
}
(None, syn::Data::Union(_)) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for unions",
));
}
syn::Data::Enum(_) => {
(None, syn::Data::Enum(_)) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for enums",
));
}
};

let predicates = &mut input.generics.make_where_clause().predicates;

predicates.extend(explicit_bounds);

for field in fields {
let ty = field.ty;
predicates.push(syn::parse_quote!(
Expand Down
Loading
Loading