diff --git a/utoipa-gen/src/component/schema.rs b/utoipa-gen/src/component/schema.rs index 74087381..45315016 100644 --- a/utoipa-gen/src/component/schema.rs +++ b/utoipa-gen/src/component/schema.rs @@ -194,14 +194,26 @@ struct NamedStructSchema<'a> { impl ToTokens for NamedStructSchema<'_> { fn to_tokens(&self, tokens: &mut TokenStream) { let container_rules = serde::parse_container(self.attributes); + let mut object_tokens = quote! { utoipa::openapi::ObjectBuilder::new() }; - tokens.extend(quote! { utoipa::openapi::ObjectBuilder::new() }); + let flatten_fields: Vec<&Field> = self + .fields + .iter() + .filter(|field| { + let field_rule = serde::parse_value(&field.attrs); + is_flatten(&field_rule) + }) + .collect(); self.fields .iter() .filter_map(|field| { let field_rule = serde::parse_value(&field.attrs); + if is_flatten(&field_rule) { + return None; + }; + if is_not_skipped(&field_rule) { Some((field, field_rule)) } else { @@ -218,53 +230,36 @@ impl ToTokens for NamedStructSchema<'_> { let name = &rename_field(&container_rules, &mut field_rule, field_name) .unwrap_or_else(|| String::from(field_name)); - let type_tree = &mut TypeTree::from_type(&field.ty); - - if let Some((generic_types, alias)) = self.generics.zip(self.alias) { - generic_types - .type_params() - .enumerate() - .for_each(|(index, generic)| { - if let Some(generic_type) = type_tree.find_mut_by_ident(&generic.ident) - { - generic_type.update_path( - &alias.generics.type_params().nth(index).unwrap().ident, - ); - }; - }) - } - - let deprecated = super::get_deprecated(&field.attrs); - let attrs = - SchemaAttr::::from_attributes_validated(&field.attrs, type_tree); - - let override_type_tree = attrs - .as_ref() - .and_then(|field| field.as_ref().value_type.as_ref().map(TypeTree::from_type)); + with_field_as_schema_property(self, field, |schema_property| { + object_tokens.extend(quote! { + .property(#name, #schema_property) + }); - let xml_value = attrs - .as_ref() - .and_then(|named_field| named_field.as_ref().xml.as_ref()); - let comments = CommentAttributes::from_attributes(&field.attrs); + if !schema_property.is_option() && !is_default(&container_rules, &field_rule) { + object_tokens.extend(quote! { + .required(#name) + }) + } + }) + }); - let schema_property = SchemaProperty::new( - override_type_tree.as_ref().unwrap_or(type_tree), - Some(&comments), - attrs.as_ref(), - deprecated.as_ref(), - xml_value, - ); + if !flatten_fields.is_empty() { + tokens.extend(quote! { + utoipa::openapi::AllOfBuilder::new() + }); - tokens.extend(quote! { - .property(#name, #schema_property) - }); + for field in flatten_fields { + with_field_as_schema_property(self, field, |schema_property| { + tokens.extend(quote! { .item(#schema_property) }); + }) + } - if !schema_property.is_option() && !is_default(&container_rules, &field_rule) { - tokens.extend(quote! { - .required(#name) - }) - } - }); + tokens.extend(quote! { + .item(#object_tokens) + }) + } else { + tokens.extend(object_tokens) + } if let Some(deprecated) = super::get_deprecated(self.attributes) { tokens.extend(quote! { .deprecated(Some(#deprecated)) }); @@ -283,6 +278,44 @@ impl ToTokens for NamedStructSchema<'_> { } } +fn with_field_as_schema_property( + schema: &NamedStructSchema, + field: &Field, + yield_: impl FnOnce(SchemaProperty<'_, NamedField<'_>>) -> R, +) -> R { + let type_tree = &mut TypeTree::from_type(&field.ty); + + if let Some((generic_types, alias)) = schema.generics.zip(schema.alias) { + generic_types + .type_params() + .enumerate() + .for_each(|(index, generic)| { + if let Some(generic_type) = type_tree.find_mut_by_ident(&generic.ident) { + generic_type + .update_path(&alias.generics.type_params().nth(index).unwrap().ident); + }; + }) + } + + let deprecated = super::get_deprecated(&field.attrs); + let attrs = SchemaAttr::::from_attributes_validated(&field.attrs, type_tree); + let override_type_tree = attrs + .as_ref() + .and_then(|field| field.as_ref().value_type.as_ref().map(TypeTree::from_type)); + let xml_value = attrs + .as_ref() + .and_then(|named_field| named_field.as_ref().xml.as_ref()); + let comments = CommentAttributes::from_attributes(&field.attrs); + + yield_(SchemaProperty::new( + override_type_tree.as_ref().unwrap_or(type_tree), + Some(&comments), + attrs.as_ref(), + deprecated.as_ref(), + xml_value, + )) +} + #[inline] fn is_default(container_rules: &Option, field_rule: &Option) -> bool { *container_rules @@ -1154,6 +1187,13 @@ fn is_not_skipped(rule: &Option) -> bool { .unwrap_or(true) } +#[inline] +fn is_flatten(rule: &Option) -> bool { + rule.as_ref() + .map(|value| value.flatten.is_some()) + .unwrap_or(false) +} + /// Resolves the appropriate [`RenameRule`] to apply to the specified `struct` `field` name given a /// `container_rule` (`struct` or `enum` level) and `field_rule` (`struct` field or `enum` variant /// level). Returns `Some` of the result of the `rename_op` if a rename is required by the supplied diff --git a/utoipa-gen/src/component/serde.rs b/utoipa-gen/src/component/serde.rs index 704978fd..2baba76a 100644 --- a/utoipa-gen/src/component/serde.rs +++ b/utoipa-gen/src/component/serde.rs @@ -26,6 +26,7 @@ pub struct SerdeValue { pub skip: Option, pub rename: Option, pub default: Option, + pub flatten: Option, } impl SerdeValue { @@ -37,6 +38,7 @@ impl SerdeValue { while let Some((tt, next)) = rest.token_tree() { match tt { TokenTree::Ident(ident) if ident == "skip" => value.skip = Some(true), + TokenTree::Ident(ident) if ident == "flatten" => value.flatten = Some(true), TokenTree::Ident(ident) if ident == "rename" => { if let Some((literal, _)) = parse_next_lit_str(next) { value.rename = Some(literal) diff --git a/utoipa-gen/src/lib.rs b/utoipa-gen/src/lib.rs index 380c81f0..9e8bcfec 100644 --- a/utoipa-gen/src/lib.rs +++ b/utoipa-gen/src/lib.rs @@ -125,6 +125,7 @@ use ext::ArgumentResolver; /// * `skip = "..."` Supported **only** in field or variant level. /// * `tag = "..."` Supported in container level. `tag` attribute also works as a [discriminator field][discriminator] for an enum. /// * `default` Supported in container level and field level according to [serde attributes]. +/// * `flatten` Supported in field level. /// /// Other _`serde`_ attributes works as is but does not have any effect on the generated OpenAPI doc. /// diff --git a/utoipa-gen/tests/schema_derive_test.rs b/utoipa-gen/tests/schema_derive_test.rs index 742953a4..2de688e0 100644 --- a/utoipa-gen/tests/schema_derive_test.rs +++ b/utoipa-gen/tests/schema_derive_test.rs @@ -1284,6 +1284,107 @@ fn derive_complex_enum_serde_tag() { ); } +#[test] +fn derive_serde_flatten() { + #[derive(Serialize)] + struct Metadata { + category: String, + total: u64, + } + + #[derive(Serialize)] + struct Record { + amount: i64, + description: String, + #[serde(flatten)] + metadata: Metadata, + } + + #[derive(Serialize)] + struct Pagination { + page: i64, + next_page: i64, + per_page: i64, + } + + // Single flatten field + let value: Value = api_doc! { + #[derive(Serialize)] + struct Record { + amount: i64, + description: String, + #[serde(flatten)] + metadata: Metadata, + } + }; + + assert_json_eq!( + value, + json!({ + "allOf": [ + { + "$ref": "#/components/schemas/Metadata" + }, + { + "type": "object", + "properties": { + "amount": { + "type": "integer", + "format": "int64" + }, + "description": { + "type": "string", + }, + }, + "required": [ + "amount", + "description" + ], + }, + ] + }) + ); + + // Multiple flatten fields, with field that contain flatten as well. + // Record contain Metadata that is flatten as well, but it doesn't matter + // here as the generated spec will reference to Record directly. + let value: Value = api_doc! { + #[derive(Serialize)] + struct NamedFields { + id: &'static str, + #[serde(flatten)] + record: Record, + #[serde(flatten)] + pagination: Pagination + } + }; + + assert_json_eq!( + value, + json!({ + "allOf": [ + { + "$ref": "#/components/schemas/Record" + }, + { + "$ref": "#/components/schemas/Pagination" + }, + { + "type": "object", + "properties": { + "id": { + "type": "string", + }, + }, + "required": [ + "id", + ], + }, + ] + }) + ); +} + #[test] fn derive_complex_enum_serde_tag_title() { #[derive(Serialize)]