Skip to content

Commit

Permalink
Merge pull request #2795 from Mingun/has-flatten-rework
Browse files Browse the repository at this point in the history
`has_flatten` rework
  • Loading branch information
dtolnay authored Aug 12, 2024
2 parents 85c73ef + 77a6a9d commit f986609
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 114 deletions.
42 changes: 17 additions & 25 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,11 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
} else if let attr::Identifier::No = cont.attrs.identifier() {
match &cont.data {
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
Data::Struct(Style::Struct, fields) => deserialize_struct(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
StructForm::Struct,
),
Data::Struct(Style::Struct, fields) => {
deserialize_struct(params, fields, &cont.attrs, StructForm::Struct)
}
Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => {
deserialize_tuple(
params,
fields,
&cont.attrs,
cont.attrs.has_flatten(),
TupleForm::Tuple,
)
deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple)
}
Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs),
}
Expand Down Expand Up @@ -469,11 +459,10 @@ fn deserialize_tuple(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: TupleForm,
) -> Fragment {
assert!(
!has_flatten,
!has_flatten(fields),
"tuples and tuple variants cannot have flatten fields"
);

Expand Down Expand Up @@ -594,7 +583,7 @@ fn deserialize_tuple_in_place(
cattrs: &attr::Container,
) -> Fragment {
assert!(
!cattrs.has_flatten(),
!has_flatten(fields),
"tuples and tuple variants cannot have flatten fields"
);

Expand Down Expand Up @@ -927,7 +916,6 @@ fn deserialize_struct(
params: &Parameters,
fields: &[Field],
cattrs: &attr::Container,
has_flatten: bool,
form: StructForm,
) -> Fragment {
let this_type = &params.this_type;
Expand Down Expand Up @@ -976,6 +964,8 @@ fn deserialize_struct(
)
})
.collect();

let has_flatten = has_flatten(fields);
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten);

// untagged struct variants do not get a visit_seq method. The same applies to
Expand Down Expand Up @@ -1115,7 +1105,7 @@ fn deserialize_struct_in_place(
) -> Option<Fragment> {
// for now we do not support in_place deserialization for structs that
// are represented as map.
if cattrs.has_flatten() {
if has_flatten(fields) {
return None;
}

Expand Down Expand Up @@ -1831,14 +1821,12 @@ fn deserialize_externally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::ExternallyTagged(variant_ident),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::ExternallyTagged(variant_ident),
),
}
Expand Down Expand Up @@ -1882,7 +1870,6 @@ fn deserialize_internally_tagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::InternallyTagged(variant_ident, deserializer),
),
Style::Tuple => unreachable!("checked in serde_derive_internals"),
Expand Down Expand Up @@ -1933,14 +1920,12 @@ fn deserialize_untagged_variant(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
TupleForm::Untagged(variant_ident, deserializer),
),
Style::Struct => deserialize_struct(
params,
&variant.fields,
cattrs,
variant.attrs.has_flatten(),
StructForm::Untagged(variant_ident, deserializer),
),
}
Expand Down Expand Up @@ -2707,7 +2692,7 @@ fn deserialize_map_in_place(
cattrs: &attr::Container,
) -> Fragment {
assert!(
!cattrs.has_flatten(),
!has_flatten(fields),
"inplace deserialization of maps does not support flatten fields"
);

Expand Down Expand Up @@ -3042,6 +3027,13 @@ fn effective_style(variant: &Variant) -> Style {
}
}

/// True if there are fields that is not skipped and has a `#[serde(flatten)]` attribute.
fn has_flatten(fields: &[Field]) -> bool {
fields
.iter()
.any(|field| field.attrs.flatten() && !field.attrs.skip_deserializing())
}

struct DeImplGenerics<'a>(&'a Parameters);
#[cfg(feature = "deserialize_in_place")]
struct InPlaceImplGenerics<'a>(&'a Parameters);
Expand Down
14 changes: 1 addition & 13 deletions serde_derive/src/internals/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<'a> Container<'a> {
item: &'a syn::DeriveInput,
derive: Derive,
) -> Option<Container<'a>> {
let mut attrs = attr::Container::from_ast(cx, item);
let attrs = attr::Container::from_ast(cx, item);

let mut data = match &item.data {
syn::Data::Enum(data) => Data::Enum(enum_from_ast(cx, &data.variants, attrs.default())),
Expand All @@ -77,16 +77,11 @@ impl<'a> Container<'a> {
}
};

let mut has_flatten = false;
match &mut data {
Data::Enum(variants) => {
for variant in variants {
variant.attrs.rename_by_rules(attrs.rename_all_rules());
for field in &mut variant.fields {
if field.attrs.flatten() {
has_flatten = true;
variant.attrs.mark_has_flatten();
}
field.attrs.rename_by_rules(
variant
.attrs
Expand All @@ -98,18 +93,11 @@ impl<'a> Container<'a> {
}
Data::Struct(_, fields) => {
for field in fields {
if field.attrs.flatten() {
has_flatten = true;
}
field.attrs.rename_by_rules(attrs.rename_all_rules());
}
}
}

if has_flatten {
attrs.mark_has_flatten();
}

let mut item = Container {
ident: item.ident.clone(),
attrs,
Expand Down
47 changes: 0 additions & 47 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,6 @@ pub struct Container {
type_into: Option<syn::Type>,
remote: Option<syn::Path>,
identifier: Identifier,
/// True if container is a struct and has a field with `#[serde(flatten)]`,
/// or is an enum with a struct variant which has a field with
/// `#[serde(flatten)]`.
///
/// ```ignore
/// struct Container {
/// #[serde(flatten)]
/// some_field: (),
/// }
/// enum Container {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
serde_path: Option<syn::Path>,
is_packed: bool,
/// Error message generated when type can't be deserialized
Expand Down Expand Up @@ -603,7 +586,6 @@ impl Container {
type_into: type_into.get(),
remote: remote.get(),
identifier: decide_identifier(cx, item, field_identifier, variant_identifier),
has_flatten: false,
serde_path: serde_path.get(),
is_packed,
expecting: expecting.get(),
Expand Down Expand Up @@ -671,14 +653,6 @@ impl Container {
self.identifier
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn custom_serde_path(&self) -> Option<&syn::Path> {
self.serde_path.as_ref()
}
Expand Down Expand Up @@ -810,18 +784,6 @@ pub struct Variant {
rename_all_rules: RenameAllRules,
ser_bound: Option<Vec<syn::WherePredicate>>,
de_bound: Option<Vec<syn::WherePredicate>>,
/// True if variant is a struct variant which contains a field with
/// `#[serde(flatten)]`.
///
/// ```ignore
/// enum Enum {
/// Variant {
/// #[serde(flatten)]
/// some_field: (),
/// },
/// }
/// ```
has_flatten: bool,
skip_deserializing: bool,
skip_serializing: bool,
other: bool,
Expand Down Expand Up @@ -991,7 +953,6 @@ impl Variant {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
has_flatten: false,
skip_deserializing: skip_deserializing.get(),
skip_serializing: skip_serializing.get(),
other: other.get(),
Expand Down Expand Up @@ -1034,14 +995,6 @@ impl Variant {
self.de_bound.as_ref().map(|vec| &vec[..])
}

pub fn has_flatten(&self) -> bool {
self.has_flatten
}

pub fn mark_has_flatten(&mut self) {
self.has_flatten = true;
}

pub fn skip_deserializing(&self) -> bool {
self.skip_deserializing
}
Expand Down
33 changes: 12 additions & 21 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,18 @@ fn serialize_tuple_struct(
}

fn serialize_struct(params: &Parameters, fields: &[Field], cattrs: &attr::Container) -> Fragment {
assert!(fields.len() as u64 <= u64::from(u32::MAX));
assert!(
fields.len() as u64 <= u64::from(u32::MAX),
"too many fields in {}: {}, maximum supported count is {}",
cattrs.name().serialize_name(),
fields.len(),
u32::MAX
);

if cattrs.has_flatten() {
let has_non_skipped_flatten = fields
.iter()
.any(|field| field.attrs.flatten() && !field.attrs.skip_serializing());
if has_non_skipped_flatten {
serialize_struct_as_map(params, fields, cattrs)
} else {
serialize_struct_as_struct(params, fields, cattrs)
Expand Down Expand Up @@ -370,26 +379,8 @@ fn serialize_struct_as_map(

let let_mut = mut_if(serialized_fields.peek().is_some() || tag_field_exists);

let len = if cattrs.has_flatten() {
quote!(_serde::__private::None)
} else {
let len = serialized_fields
.map(|field| match field.attrs.skip_serializing_if() {
None => quote!(1),
Some(path) => {
let field_expr = get_member(params, field, &field.member);
quote!(if #path(#field_expr) { 0 } else { 1 })
}
})
.fold(
quote!(#tag_field_exists as usize),
|sum, expr| quote!(#sum + #expr),
);
quote!(_serde::__private::Some(#len))
};

quote_block! {
let #let_mut __serde_state = _serde::Serializer::serialize_map(__serializer, #len)?;
let #let_mut __serde_state = _serde::Serializer::serialize_map(__serializer, _serde::__private::None)?;
#tag_field
#(#serialize_fields)*
_serde::ser::SerializeMap::end(__serde_state)
Expand Down
32 changes: 24 additions & 8 deletions test_suite/tests/test_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,32 @@ fn test_gen() {
}
assert::<FlattenWith>();

#[derive(Serialize, Deserialize)]
pub struct Flatten<T> {
#[serde(flatten)]
t: T,
}

#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FlattenDenyUnknown<T> {
#[serde(flatten)]
t: T,
}

#[derive(Serialize, Deserialize)]
pub struct SkipDeserializing<T> {
#[serde(skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SkipDeserializingDenyUnknown<T> {
#[serde(skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
pub struct StaticStrStruct<'a> {
a: &'a str,
Expand Down Expand Up @@ -720,14 +739,11 @@ fn test_gen() {
flat: StdOption<T>,
}

#[allow(clippy::collection_is_never_read)] // FIXME
const _: () = {
#[derive(Serialize, Deserialize)]
pub struct FlattenSkipDeserializing<T> {
#[serde(flatten, skip_deserializing)]
flat: T,
}
};
#[derive(Serialize, Deserialize)]
pub struct FlattenSkipDeserializing<T> {
#[serde(flatten, skip_deserializing)]
flat: T,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
Expand Down

0 comments on commit f986609

Please sign in to comment.