Skip to content

Commit

Permalink
Only put Debug-like bounds on type variables (#371, #363)
Browse files Browse the repository at this point in the history
## Synopsis

The problem, as reported in the issue, is that code like the following

```rust
#[derive(derive_more::Debug)]
struct Item {
    next: Option<Box<Item>>,
}
```

expands into something like

```rust
impl std::fmt::Debug for Item where Item: Debug { /* ... */ }
```

which does not compile. This PR changes the Debug derive so it does not
emit those bounds.

## Solution

My understanding of the current code is that we iterate over all fields
of the struct/enum and add either a specific
format bound (e.g. `: fmt::Binary`), a default `: fmt::Debug` bound or
skip it if either it is marked
as `#[debug(skip)]` or the entire container has a format attribute. 

The suggested solution in the issue (if I understood it correctly) was
to only add bounds if the type is a type
variable, since rustc already knows if a concrete type is, say, `:
fmt::Debug`. So, instead of adding the bound for
every type, we first check that the type contains one of the container's
type variables. Since types can be nested, it
is an unfortunately long recursive function handling the different types
of types. This part of Rust syntax is probably
not going to change, so perhaps it is feasible to shorten some of the
branches into `_ => false`.

One drawback of this implementation is that we iterate over the list of
type variables every time we find a "leaf type".
I chose `Vec` over `HashSet` because in my experience there are only a
handful of type variables per container.

Co-authored-by: Jelte Fennema-Nio <github-tech@jeltef.nl>
Co-authored-by: Kai Ren <tyranron@gmail.com>
  • Loading branch information
3 people authored Jul 19, 2024
1 parent af823ea commit 162535e
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 6 deletions.
143 changes: 137 additions & 6 deletions impl/src/fmt/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,21 @@ pub fn expand(input: &syn::DeriveInput, _: &str) -> syn::Result<TokenStream> {
.unwrap_or_default();
let ident = &input.ident;

let type_params: Vec<_> = input
.generics
.params
.iter()
.filter_map(|p| match p {
syn::GenericParam::Type(t) => Some(&t.ident),
syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None,
})
.collect();

let (bounds, body) = match &input.data {
syn::Data::Struct(s) => expand_struct(attrs, ident, s, &attr_name),
syn::Data::Enum(e) => expand_enum(attrs, e, &attr_name),
syn::Data::Struct(s) => {
expand_struct(attrs, ident, s, &type_params, &attr_name)
}
syn::Data::Enum(e) => expand_enum(attrs, e, &type_params, &attr_name),
syn::Data::Union(_) => {
return Err(syn::Error::new(
input.span(),
Expand Down Expand Up @@ -64,11 +76,13 @@ fn expand_struct(
attrs: ContainerAttributes,
ident: &Ident,
s: &syn::DataStruct,
type_params: &[&syn::Ident],
attr_name: &syn::Ident,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let s = Expansion {
attr: &attrs,
fields: &s.fields,
type_params,
ident,
attr_name,
};
Expand Down Expand Up @@ -99,6 +113,7 @@ fn expand_struct(
fn expand_enum(
mut attrs: ContainerAttributes,
e: &syn::DataEnum,
type_params: &[&syn::Ident],
attr_name: &syn::Ident,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
if let Some(enum_fmt) = attrs.fmt.as_ref() {
Expand Down Expand Up @@ -136,6 +151,7 @@ fn expand_enum(
let v = Expansion {
attr: &attrs,
fields: &variant.fields,
type_params,
ident,
attr_name,
};
Expand Down Expand Up @@ -195,6 +211,9 @@ struct Expansion<'a> {
/// Struct or enum [`syn::Fields`].
fields: &'a syn::Fields,

/// Type parameters in this struct or enum.
type_params: &'a [&'a syn::Ident],

/// Name of the attributes, considered by this macro.
attr_name: &'a syn::Ident,
}
Expand Down Expand Up @@ -334,15 +353,26 @@ impl<'a> Expansion<'a> {
let mut out = self.attr.bounds.0.clone().into_iter().collect::<Vec<_>>();

if let Some(fmt) = self.attr.fmt.as_ref() {
out.extend(fmt.bounded_types(self.fields).map(|(ty, trait_name)| {
let trait_ident = format_ident!("{trait_name}");
out.extend(fmt.bounded_types(self.fields).filter_map(
|(ty, trait_name)| {
if !self.contains_generic_param(ty) {
return None;
}

let trait_ident = format_ident!("{trait_name}");

parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
}));
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
},
));
Ok(out)
} else {
self.fields.iter().try_fold(out, |mut out, field| {
let ty = &field.ty;

if !self.contains_generic_param(ty) {
return Ok(out);
}

match FieldAttribute::parse_attrs(&field.attrs, self.attr_name)?
.map(Spanning::into_inner)
{
Expand All @@ -362,4 +392,105 @@ impl<'a> Expansion<'a> {
})
}
}

/// Checks whether the provided [`syn::Path`] contains any of these [`Expansion::type_params`].
fn path_contains_generic_param(&self, path: &syn::Path) -> bool {
path.segments
.iter()
.any(|segment| match &segment.arguments {
syn::PathArguments::None => false,
syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments { args, .. },
) => args.iter().any(|generic| match generic {
syn::GenericArgument::Type(ty)
| syn::GenericArgument::AssocType(syn::AssocType { ty, .. }) => {
self.contains_generic_param(ty)
}

syn::GenericArgument::Lifetime(_)
| syn::GenericArgument::Const(_)
| syn::GenericArgument::AssocConst(_)
| syn::GenericArgument::Constraint(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}),
syn::PathArguments::Parenthesized(
syn::ParenthesizedGenericArguments { inputs, output, .. },
) => {
inputs.iter().any(|ty| self.contains_generic_param(ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => {
self.contains_generic_param(ty)
}
}
}
})
}

/// Checks whether the provided [`syn::Type`] contains any of these [`Expansion::type_params`].
fn contains_generic_param(&self, ty: &syn::Type) -> bool {
if self.type_params.is_empty() {
return false;
}
match ty {
syn::Type::Path(syn::TypePath { qself, path }) => {
if let Some(qself) = qself {
if self.contains_generic_param(&qself.ty) {
return true;
}
}

if let Some(ident) = path.get_ident() {
self.type_params.iter().any(|param| *param == ident)
} else {
self.path_contains_generic_param(path)
}
}

syn::Type::Array(syn::TypeArray { elem, .. })
| syn::Type::Group(syn::TypeGroup { elem, .. })
| syn::Type::Paren(syn::TypeParen { elem, .. })
| syn::Type::Ptr(syn::TypePtr { elem, .. })
| syn::Type::Reference(syn::TypeReference { elem, .. })
| syn::Type::Slice(syn::TypeSlice { elem, .. }) => {
self.contains_generic_param(elem)
}

syn::Type::BareFn(syn::TypeBareFn { inputs, output, .. }) => {
inputs
.iter()
.any(|arg| self.contains_generic_param(&arg.ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => self.contains_generic_param(ty),
}
}
syn::Type::Tuple(syn::TypeTuple { elems, .. }) => {
elems.iter().any(|ty| self.contains_generic_param(ty))
}

syn::Type::ImplTrait(_) => false,
syn::Type::Infer(_) => false,
syn::Type::Macro(_) => false,
syn::Type::Never(_) => false,
syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) => {
bounds.iter().any(|bound| match bound {
syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => {
self.path_contains_generic_param(path)
}
syn::TypeParamBound::Lifetime(_) => false,
syn::TypeParamBound::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
})
}
syn::Type::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}
}
}
138 changes: 138 additions & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1932,3 +1932,141 @@ mod complex_enum_syntax {
assert_eq!(format!("{:?}", Enum::A), "A");
}
}

// See: https://github.com/JelteF/derive_more/issues/363
mod type_variables {
mod our_alloc {
#[cfg(not(feature = "std"))]
pub use alloc::{boxed::Box, format, vec, vec::Vec};
#[cfg(feature = "std")]
pub use std::{boxed::Box, format, vec, vec::Vec};
}

use our_alloc::{format, vec, Box, Vec};

use derive_more::Debug;

#[derive(Debug)]
struct ItemStruct {
next: Option<Box<ItemStruct>>,
}

#[derive(Debug)]
struct ItemTuple(Option<Box<ItemTuple>>);

#[derive(Debug)]
#[debug("Item({_0:?})")]
struct ItemTupleContainerFmt(Option<Box<ItemTupleContainerFmt>>);

#[derive(Debug)]
enum ItemEnum {
Node { children: Vec<ItemEnum>, inner: i32 },
Leaf { inner: i32 },
}

#[derive(Debug)]
struct VecMeansDifferent<Vec> {
next: our_alloc::Vec<i32>,
real: Vec,
}

#[derive(Debug)]
struct Array<T> {
#[debug("{t}")]
t: [T; 10],
}

mod parens {
#![allow(unused_parens)] // test that type is found even in parentheses

use derive_more::Debug;

#[derive(Debug)]
struct Paren<T> {
t: (T),
}
}

#[derive(Debug)]
struct ParenthesizedGenericArgumentsInput<T> {
t: dyn Fn(T) -> i32,
}

#[derive(Debug)]
struct ParenthesizedGenericArgumentsOutput<T> {
t: dyn Fn(i32) -> T,
}

#[derive(Debug)]
struct Ptr<T> {
t: *const T,
}

#[derive(Debug)]
struct Reference<'a, T> {
t: &'a T,
}

#[derive(Debug)]
struct Slice<'a, T> {
t: &'a [T],
}

#[derive(Debug)]
struct BareFn<T> {
t: Box<fn(T) -> T>,
}

#[derive(Debug)]
struct Tuple<T> {
t: Box<(T, T)>,
}

trait MyTrait<T> {}

#[derive(Debug)]
struct TraitObject<T> {
t: Box<dyn MyTrait<T>>,
}

#[test]
fn assert() {
assert_eq!(
format!(
"{:?}",
ItemStruct {
next: Some(Box::new(ItemStruct { next: None }))
},
),
"ItemStruct { next: Some(ItemStruct { next: None }) }",
);

assert_eq!(
format!("{:?}", ItemTuple(Some(Box::new(ItemTuple(None))))),
"ItemTuple(Some(ItemTuple(None)))",
);

assert_eq!(
format!(
"{:?}",
ItemTupleContainerFmt(Some(Box::new(ItemTupleContainerFmt(None)))),
),
"Item(Some(Item(None)))",
);

let item = ItemEnum::Node {
children: vec![
ItemEnum::Node {
children: vec![],
inner: 0,
},
ItemEnum::Leaf { inner: 1 },
],
inner: 2,
};
assert_eq!(
format!("{item:?}"),
"Node { children: [Node { children: [], inner: 0 }, Leaf { inner: 1 }], inner: 2 }",
)
}
}
1 change: 1 addition & 0 deletions tests/sum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![allow(dead_code)] // some code is tested for type checking only

use derive_more::Sum;

Expand Down

0 comments on commit 162535e

Please sign in to comment.