Skip to content

Commit

Permalink
Add support for serde flatten (#325)
Browse files Browse the repository at this point in the history
Add support for serde `flatten` attribute in `ToSchema` derive macro. 
Add tests and update docs.

Fixes #120
  • Loading branch information
kw7oe committed Oct 28, 2022
1 parent b8cdfe4 commit 0332914
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 44 deletions.
128 changes: 84 additions & 44 deletions utoipa-gen/src/component/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,26 @@ struct NamedStructSchema<'a> {
impl ToTokens for NamedStructSchema<'_> {
fn to_tokens(&self, tokens: &mut TokenStream) {
let container_rules = serde::parse_container(self.attributes);
let mut object_tokens = quote! { utoipa::openapi::ObjectBuilder::new() };

tokens.extend(quote! { utoipa::openapi::ObjectBuilder::new() });
let flatten_fields: Vec<&Field> = self
.fields
.iter()
.filter(|field| {
let field_rule = serde::parse_value(&field.attrs);
is_flatten(&field_rule)
})
.collect();

self.fields
.iter()
.filter_map(|field| {
let field_rule = serde::parse_value(&field.attrs);

if is_flatten(&field_rule) {
return None;
};

if is_not_skipped(&field_rule) {
Some((field, field_rule))
} else {
Expand All @@ -218,53 +230,36 @@ impl ToTokens for NamedStructSchema<'_> {
let name = &rename_field(&container_rules, &mut field_rule, field_name)
.unwrap_or_else(|| String::from(field_name));

let type_tree = &mut TypeTree::from_type(&field.ty);

if let Some((generic_types, alias)) = self.generics.zip(self.alias) {
generic_types
.type_params()
.enumerate()
.for_each(|(index, generic)| {
if let Some(generic_type) = type_tree.find_mut_by_ident(&generic.ident)
{
generic_type.update_path(
&alias.generics.type_params().nth(index).unwrap().ident,
);
};
})
}

let deprecated = super::get_deprecated(&field.attrs);
let attrs =
SchemaAttr::<NamedField>::from_attributes_validated(&field.attrs, type_tree);

let override_type_tree = attrs
.as_ref()
.and_then(|field| field.as_ref().value_type.as_ref().map(TypeTree::from_type));
with_field_as_schema_property(self, field, |schema_property| {
object_tokens.extend(quote! {
.property(#name, #schema_property)
});

let xml_value = attrs
.as_ref()
.and_then(|named_field| named_field.as_ref().xml.as_ref());
let comments = CommentAttributes::from_attributes(&field.attrs);
if !schema_property.is_option() && !is_default(&container_rules, &field_rule) {
object_tokens.extend(quote! {
.required(#name)
})
}
})
});

let schema_property = SchemaProperty::new(
override_type_tree.as_ref().unwrap_or(type_tree),
Some(&comments),
attrs.as_ref(),
deprecated.as_ref(),
xml_value,
);
if !flatten_fields.is_empty() {
tokens.extend(quote! {
utoipa::openapi::AllOfBuilder::new()
});

tokens.extend(quote! {
.property(#name, #schema_property)
});
for field in flatten_fields {
with_field_as_schema_property(self, field, |schema_property| {
tokens.extend(quote! { .item(#schema_property) });
})
}

if !schema_property.is_option() && !is_default(&container_rules, &field_rule) {
tokens.extend(quote! {
.required(#name)
})
}
});
tokens.extend(quote! {
.item(#object_tokens)
})
} else {
tokens.extend(object_tokens)
}

if let Some(deprecated) = super::get_deprecated(self.attributes) {
tokens.extend(quote! { .deprecated(Some(#deprecated)) });
Expand All @@ -283,6 +278,44 @@ impl ToTokens for NamedStructSchema<'_> {
}
}

fn with_field_as_schema_property<R>(
schema: &NamedStructSchema,
field: &Field,
yield_: impl FnOnce(SchemaProperty<'_, NamedField<'_>>) -> R,
) -> R {
let type_tree = &mut TypeTree::from_type(&field.ty);

if let Some((generic_types, alias)) = schema.generics.zip(schema.alias) {
generic_types
.type_params()
.enumerate()
.for_each(|(index, generic)| {
if let Some(generic_type) = type_tree.find_mut_by_ident(&generic.ident) {
generic_type
.update_path(&alias.generics.type_params().nth(index).unwrap().ident);
};
})
}

let deprecated = super::get_deprecated(&field.attrs);
let attrs = SchemaAttr::<NamedField>::from_attributes_validated(&field.attrs, type_tree);
let override_type_tree = attrs
.as_ref()
.and_then(|field| field.as_ref().value_type.as_ref().map(TypeTree::from_type));
let xml_value = attrs
.as_ref()
.and_then(|named_field| named_field.as_ref().xml.as_ref());
let comments = CommentAttributes::from_attributes(&field.attrs);

yield_(SchemaProperty::new(
override_type_tree.as_ref().unwrap_or(type_tree),
Some(&comments),
attrs.as_ref(),
deprecated.as_ref(),
xml_value,
))
}

#[inline]
fn is_default(container_rules: &Option<SerdeContainer>, field_rule: &Option<SerdeValue>) -> bool {
*container_rules
Expand Down Expand Up @@ -1154,6 +1187,13 @@ fn is_not_skipped(rule: &Option<SerdeValue>) -> bool {
.unwrap_or(true)
}

#[inline]
fn is_flatten(rule: &Option<SerdeValue>) -> bool {
rule.as_ref()
.map(|value| value.flatten.is_some())
.unwrap_or(false)
}

/// Resolves the appropriate [`RenameRule`] to apply to the specified `struct` `field` name given a
/// `container_rule` (`struct` or `enum` level) and `field_rule` (`struct` field or `enum` variant
/// level). Returns `Some` of the result of the `rename_op` if a rename is required by the supplied
Expand Down
2 changes: 2 additions & 0 deletions utoipa-gen/src/component/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct SerdeValue {
pub skip: Option<bool>,
pub rename: Option<String>,
pub default: Option<bool>,
pub flatten: Option<bool>,
}

impl SerdeValue {
Expand All @@ -37,6 +38,7 @@ impl SerdeValue {
while let Some((tt, next)) = rest.token_tree() {
match tt {
TokenTree::Ident(ident) if ident == "skip" => value.skip = Some(true),
TokenTree::Ident(ident) if ident == "flatten" => value.flatten = Some(true),
TokenTree::Ident(ident) if ident == "rename" => {
if let Some((literal, _)) = parse_next_lit_str(next) {
value.rename = Some(literal)
Expand Down
1 change: 1 addition & 0 deletions utoipa-gen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ use ext::ArgumentResolver;
/// * `skip = "..."` Supported **only** in field or variant level.
/// * `tag = "..."` Supported in container level. `tag` attribute also works as a [discriminator field][discriminator] for an enum.
/// * `default` Supported in container level and field level according to [serde attributes].
/// * `flatten` Supported in field level.
///
/// Other _`serde`_ attributes works as is but does not have any effect on the generated OpenAPI doc.
///
Expand Down
101 changes: 101 additions & 0 deletions utoipa-gen/tests/schema_derive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,107 @@ fn derive_complex_enum_serde_tag() {
);
}

#[test]
fn derive_serde_flatten() {
#[derive(Serialize)]
struct Metadata {
category: String,
total: u64,
}

#[derive(Serialize)]
struct Record {
amount: i64,
description: String,
#[serde(flatten)]
metadata: Metadata,
}

#[derive(Serialize)]
struct Pagination {
page: i64,
next_page: i64,
per_page: i64,
}

// Single flatten field
let value: Value = api_doc! {
#[derive(Serialize)]
struct Record {
amount: i64,
description: String,
#[serde(flatten)]
metadata: Metadata,
}
};

assert_json_eq!(
value,
json!({
"allOf": [
{
"$ref": "#/components/schemas/Metadata"
},
{
"type": "object",
"properties": {
"amount": {
"type": "integer",
"format": "int64"
},
"description": {
"type": "string",
},
},
"required": [
"amount",
"description"
],
},
]
})
);

// Multiple flatten fields, with field that contain flatten as well.
// Record contain Metadata that is flatten as well, but it doesn't matter
// here as the generated spec will reference to Record directly.
let value: Value = api_doc! {
#[derive(Serialize)]
struct NamedFields {
id: &'static str,
#[serde(flatten)]
record: Record,
#[serde(flatten)]
pagination: Pagination
}
};

assert_json_eq!(
value,
json!({
"allOf": [
{
"$ref": "#/components/schemas/Record"
},
{
"$ref": "#/components/schemas/Pagination"
},
{
"type": "object",
"properties": {
"id": {
"type": "string",
},
},
"required": [
"id",
],
},
]
})
);
}

#[test]
fn derive_complex_enum_serde_tag_title() {
#[derive(Serialize)]
Expand Down

0 comments on commit 0332914

Please sign in to comment.