diff --git a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs index bbc7cd3962720..ff149e9af24ab 100644 --- a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs +++ b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs @@ -1,21 +1,30 @@ use std::mem::swap; use ast::HasAttrs; +use rustc_ast::mut_visit::MutVisitor; +use rustc_ast::visit::BoundKind; use rustc_ast::{ self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem, TraitBoundModifiers, VariantData, }; use rustc_attr as attr; +use rustc_data_structures::flat_map_in_place::FlatMapInPlace; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{sym, Ident}; -use rustc_span::Span; +use rustc_span::{Span, Symbol}; use smallvec::{smallvec, SmallVec}; use thin_vec::{thin_vec, ThinVec}; +type AstTy = ast::ptr::P; + macro_rules! path { ($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] } } +macro_rules! symbols { + ($($part:ident)::*) => { [$(sym::$part),*] } +} + pub fn expand_deriving_smart_ptr( cx: &ExtCtxt<'_>, span: Span, @@ -143,8 +152,11 @@ pub fn expand_deriving_smart_ptr( // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. let mut impl_generics = generics.clone(); + let pointee_ty_ident = generics.params[pointee_param_idx].ident; + let mut self_bounds; { let p = &mut impl_generics.params[pointee_param_idx]; + self_bounds = p.bounds.clone(); let arg = GenericArg::Type(s_ty.clone()); let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]); p.bounds.push(cx.trait_bound(unsize, false)); @@ -152,22 +164,163 @@ pub fn expand_deriving_smart_ptr( swap(&mut p.attrs, &mut attrs); p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); } + // We should not set default values to constant generic parameters + // and write out bounds that indirectly involves `#[pointee]`. + for (idx, (params, orig_params)) in + impl_generics.params.iter_mut().zip(&generics.params).enumerate() + { + if idx == pointee_param_idx { + continue; + } + match &mut params.kind { + ast::GenericParamKind::Const { default, .. } => *default = None, + ast::GenericParamKind::Type { default } => *default = None, + ast::GenericParamKind::Lifetime => {} + } + for bound in &orig_params.bounds { + let mut bound = bound.clone(); + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: false, + }; + substitution.visit_param_bound(&mut bound, BoundKind::Bound); + if substitution.rewritten { + params.bounds.push(bound); + } + } + } // Add the `__S: ?Sized` extra parameter to the impl block. + // We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`. let sized = cx.path_global(span, path!(span, core::marker::Sized)); - let bound = GenericBound::Trait( - cx.poly_trait_ref(span, sized), - TraitBoundModifiers { - polarity: ast::BoundPolarity::Maybe(span), - constness: ast::BoundConstness::Never, - asyncness: ast::BoundAsyncness::Normal, - }, - ); - let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None); - impl_generics.params.push(extra_param); + if self_bounds.iter().all(|bound| { + if let GenericBound::Trait( + trait_ref, + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, + ) = bound + { + !is_sized_marker(&trait_ref.trait_ref.path) + } else { + false + } + }) { + self_bounds.push(GenericBound::Trait( + cx.poly_trait_ref(span, sized), + TraitBoundModifiers { + polarity: ast::BoundPolarity::Maybe(span), + constness: ast::BoundConstness::Never, + asyncness: ast::BoundAsyncness::Normal, + }, + )); + } + { + let mut substitution = + TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false }; + for bound in &mut self_bounds { + substitution.visit_param_bound(bound, BoundKind::Bound); + } + } + + // We should also commute the where bounds from `#[pointee]` to `__S` + // as well as any bound that indirectly involves the `#[pointee]` type. + for bound in &generics.where_clause.predicates { + if let ast::WherePredicate::BoundPredicate(bound) = bound { + let bound_on_pointee = bound + .bounded_ty + .kind + .is_simple_path() + .map_or(false, |name| name == pointee_ty_ident.name); + + let bounds: Vec<_> = bound + .bounds + .iter() + .filter(|bound| { + if let GenericBound::Trait( + trait_ref, + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, + ) = bound + { + !bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path) + } else { + true + } + }) + .cloned() + .collect(); + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: bounds.len() != bound.bounds.len(), + }; + let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate { + span: bound.span, + bound_generic_params: bound.bound_generic_params.clone(), + bounded_ty: bound.bounded_ty.clone(), + bounds, + }); + substitution.visit_where_predicate(&mut predicate); + if substitution.rewritten { + impl_generics.where_clause.predicates.push(predicate); + } + } + } + + let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None); + impl_generics.params.insert(pointee_param_idx + 1, extra_param); // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`. let gen_args = vec![GenericArg::Type(alt_self_type.clone())]; add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone()); add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone()); } + +fn is_sized_marker(path: &ast::Path) -> bool { + const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized); + const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized); + if path.segments.len() == 3 { + path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol) + || path + .segments + .iter() + .zip(STD_UNSIZE) + .all(|(segment, symbol)| segment.ident.name == symbol) + } else { + *path == sym::Sized + } +} + +struct TypeSubstitution<'a> { + from_name: Symbol, + to_ty: &'a AstTy, + rewritten: bool, +} + +impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> { + fn visit_ty(&mut self, ty: &mut AstTy) { + if let Some(name) = ty.kind.is_simple_path() + && name == self.from_name + { + *ty = self.to_ty.clone(); + self.rewritten = true; + } else { + ast::mut_visit::walk_ty(self, ty); + } + } + + fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) { + match where_predicate { + rustc_ast::WherePredicate::BoundPredicate(bound) => { + bound + .bound_generic_params + .flat_map_in_place(|param| self.flat_map_generic_param(param)); + self.visit_ty(&mut bound.bounded_ty); + for bound in &mut bound.bounds { + self.visit_param_bound(bound, BoundKind::Bound) + } + } + rustc_ast::WherePredicate::RegionPredicate(_) + | rustc_ast::WherePredicate::EqPredicate(_) => {} + } + } +} diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.rs b/tests/ui/deriving/deriving-smart-pointer-expanded.rs new file mode 100644 index 0000000000000..33f316ecdf07a --- /dev/null +++ b/tests/ui/deriving/deriving-smart-pointer-expanded.rs @@ -0,0 +1,22 @@ +//@ check-pass +//@ compile-flags: -Zunpretty=expanded +#![feature(derive_smart_pointer)] +use std::marker::SmartPointer; + +pub trait MyTrait {} + +#[derive(SmartPointer)] +#[repr(transparent)] +struct MyPointer<'a, #[pointee] T: ?Sized> { + ptr: &'a T, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct MyPointer2<'a, Y, Z: MyTrait, #[pointee] T: ?Sized + MyTrait, X: MyTrait> +where + Y: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.stdout b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout new file mode 100644 index 0000000000000..c0881df76a585 --- /dev/null +++ b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout @@ -0,0 +1,44 @@ +#![feature(prelude_import)] +#![no_std] +//@ check-pass +//@ compile-flags: -Zunpretty=expanded +#![feature(derive_smart_pointer)] +#[prelude_import] +use ::std::prelude::rust_2015::*; +#[macro_use] +extern crate std; +use std::marker::SmartPointer; + +pub trait MyTrait {} + +#[repr(transparent)] +struct MyPointer<'a, #[pointee] T: ?Sized> { + ptr: &'a T, +} +#[automatically_derived] +impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized> + ::core::ops::DispatchFromDyn> for MyPointer<'a, T> { +} +#[automatically_derived] +impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized> + ::core::ops::CoerceUnsized> for MyPointer<'a, T> { +} + +#[repr(transparent)] +pub struct MyPointer2<'a, Y, Z: MyTrait, #[pointee] T: ?Sized + MyTrait, + X: MyTrait> where Y: MyTrait { + data: &'a mut T, + x: core::marker::PhantomData, +} +#[automatically_derived] +impl<'a, Y, Z: MyTrait + MyTrait<__S>, T: ?Sized + MyTrait + + ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait + + MyTrait<__S>> ::core::ops::DispatchFromDyn> + for MyPointer2<'a, Y, Z, T, X> where Y: MyTrait, Y: MyTrait<__S> { +} +#[automatically_derived] +impl<'a, Y, Z: MyTrait + MyTrait<__S>, T: ?Sized + MyTrait + + ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait + + MyTrait<__S>> ::core::ops::CoerceUnsized> for + MyPointer2<'a, Y, Z, T, X> where Y: MyTrait, Y: MyTrait<__S> { +} diff --git a/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs new file mode 100644 index 0000000000000..0d4f765aade2c --- /dev/null +++ b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs @@ -0,0 +1,78 @@ +//@ check-pass + +#![feature(derive_smart_pointer)] + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait OnDrop { + fn on_drop(&mut self); +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr2<'a, #[pointee] T: ?Sized, X> +where + T: OnDrop, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait MyTrait {} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr3<'a, #[pointee] T: ?Sized, X> +where + T: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr4<'a, #[pointee] T: MyTrait + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr5<'a, #[pointee] T: ?Sized, X> +where + Ptr5Companion: MyTrait, + Ptr5Companion2: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub struct Ptr5Companion(core::marker::PhantomData); +pub struct Ptr5Companion2; + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/ +#[repr(transparent)] +#[derive(core::marker::SmartPointer)] +pub struct ListArc<#[pointee] T, const ID: u64 = 0> +where + T: ListArcSafe + ?Sized, +{ + arc: *const T, +} + +pub trait ListArcSafe {} + +fn main() {}