diff --git a/pin-project-internal/src/lib.rs b/pin-project-internal/src/lib.rs index 4bdd0ed7..eaa2d782 100644 --- a/pin-project-internal/src/lib.rs +++ b/pin-project-internal/src/lib.rs @@ -289,7 +289,7 @@ pub fn pin_project(args: TokenStream, input: TokenStream) -> TokenStream { /// `Pin<&mut Self>`. In particular, it will never be called more than once, /// just like [`Drop::drop`]. /// -/// Example: +/// ## Example /// /// ```rust /// use pin_project::{pin_project, pinned_drop}; @@ -325,14 +325,55 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// *This attribute is available if pin-project is built with the /// `"project_attr"` feature.* /// -/// The attribute at the expression position is not stable, so you need to use -/// a dummy `#[project]` attribute for the function. +/// The following three syntaxes are supported. /// -/// ## Examples +/// ## `impl` blocks +/// +/// All methods (and associated functions) in `#[project] impl` block become +/// methods of the projected type. If you want to implement methods on the +/// original type, you need to create another (non-`#[project]`) `impl` block. +/// +/// To call a method implemented in `#[project] impl` block, you need to first +/// get the projected-type with `let this = self.project();`. +/// +/// ### Examples +/// +/// ```rust +/// use pin_project::{pin_project, project}; +/// use std::pin::Pin; +/// +/// #[pin_project] +/// struct Foo { +/// #[pin] +/// future: T, +/// field: U, +/// } +/// +/// // impl for the original type +/// impl Foo { +/// fn bar(mut self: Pin<&mut Self>) { +/// self.project().baz() +/// } +/// } +/// +/// // impl for the projected type +/// #[project] +/// impl Foo { +/// fn baz(self) { +/// let Self { future, field } = self; +/// +/// let _: Pin<&mut T> = future; +/// let _: &mut U = field; +/// } +/// } +/// ``` +/// +/// ## `let` bindings /// -/// The following two syntaxes are supported. +/// *The attribute at the expression position is not stable, so you need to use +/// a dummy `#[project]` attribute for the function.* /// -/// ### `let` bindings +/// ### Examples /// /// ```rust /// use pin_project::{pin_project, project}; @@ -357,7 +398,12 @@ pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { /// } /// ``` /// -/// ### `match` expressions +/// ## `match` expressions +/// +/// *The attribute at the expression position is not stable, so you need to use +/// a dummy `#[project]` attribute for the function.* +/// +/// ### Examples /// /// ```rust /// use pin_project::{project, pin_project}; diff --git a/pin-project-internal/src/pin_project/mod.rs b/pin-project-internal/src/pin_project/mod.rs index 80fdb902..481acc3e 100644 --- a/pin-project-internal/src/pin_project/mod.rs +++ b/pin-project-internal/src/pin_project/mod.rs @@ -7,7 +7,7 @@ use syn::{ *, }; -use crate::utils::{crate_path, proj_ident, proj_trait_ident}; +use crate::utils::{self, crate_path, proj_ident, proj_trait_ident}; mod enums; mod structs; @@ -208,34 +208,14 @@ fn ensure_not_packed(item: &ItemStruct) -> Result { /// Determine the lifetime names. Ensure it doesn't overlap with any existing lifetime names. fn proj_lifetime(generics: &Punctuated) -> Lifetime { let mut lifetime_name = String::from("'_pin"); - let existing_lifetimes: Vec = generics - .iter() - .filter_map(|param| { - if let GenericParam::Lifetime(LifetimeDef { lifetime, .. }) = param { - Some(lifetime.to_string()) - } else { - None - } - }) - .collect(); - while existing_lifetimes.iter().any(|name| *name == lifetime_name) { - lifetime_name.push('_'); - } + utils::proj_lifetime_name(&mut lifetime_name, generics); Lifetime::new(&lifetime_name, Span::call_site()) } /// Makes the generics of projected type from the reference of the original generics. fn proj_generics(generics: &Generics, lifetime: &Lifetime) -> Generics { let mut generics = generics.clone(); - generics.params.insert( - 0, - GenericParam::Lifetime(LifetimeDef { - attrs: Vec::new(), - lifetime: lifetime.clone(), - colon_token: None, - bounds: Punctuated::new(), - }), - ); + utils::proj_generics(&mut generics, lifetime.clone()); generics } diff --git a/pin-project-internal/src/pinned_drop.rs b/pin-project-internal/src/pinned_drop.rs index f174df7a..f867736e 100644 --- a/pin-project-internal/src/pinned_drop.rs +++ b/pin-project-internal/src/pinned_drop.rs @@ -31,7 +31,7 @@ fn parse_arg(arg: &FnArg) -> Result<&Type> { } } - Err(error!(&arg, "#[pinned_drop] function must take a argument `Pin<&mut Type>`")) + Err(error!(arg, "#[pinned_drop] function must take a argument `Pin<&mut Type>`")) } fn parse(input: TokenStream) -> Result { diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 679ddd44..bcde4882 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -1,4 +1,4 @@ -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::ToTokens; use syn::{ parse::Nothing, @@ -8,7 +8,7 @@ use syn::{ *, }; -use crate::utils::{proj_ident, VecExt}; +use crate::utils::{proj_generics, proj_ident, proj_lifetime_name, VecExt}; /// The attribute name. const NAME: &str = "project"; @@ -25,6 +25,7 @@ fn parse(input: TokenStream) -> Result { } Stmt::Local(local) => local.replace(&mut Register::default()), Stmt::Item(Item::Fn(ItemFn { block, .. })) => Dummy.visit_block_mut(block), + Stmt::Item(Item::Impl(item)) => item.replace(&mut Register::default()), _ => {} } @@ -36,6 +37,37 @@ trait Replace { fn replace(&mut self, register: &mut Register); } +impl Replace for ItemImpl { + fn replace(&mut self, _: &mut Register) { + let PathSegment { ident, arguments } = match &mut *self.self_ty { + Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(), + _ => return, + }; + + replace_ident(ident); + + let mut lifetime_name = String::from("'_pin"); + proj_lifetime_name(&mut lifetime_name, &self.generics.params); + self.items + .iter_mut() + .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None }) + .for_each(|item| proj_lifetime_name(&mut lifetime_name, &item.sig.generics.params)); + let lifetime = Lifetime::new(&lifetime_name, Span::call_site()); + + proj_generics(&mut self.generics, syn::parse_quote!(#lifetime)); + + match arguments { + PathArguments::None => { + *arguments = PathArguments::AngleBracketed(syn::parse_quote!(<#lifetime>)); + } + PathArguments::AngleBracketed(args) => { + args.args.insert(0, syn::parse_quote!(#lifetime)); + } + PathArguments::Parenthesized(_) => unreachable!(), + } + } +} + impl Replace for Local { fn replace(&mut self, register: &mut Register) { self.pat.replace(register); @@ -83,11 +115,15 @@ impl Replace for Path { if register.0.is_none() || register.eq(&self.segments[0].ident, len) { register.update(&self.segments[0].ident, len); - self.segments[0].ident = proj_ident(&self.segments[0].ident) + replace_ident(&mut self.segments[0].ident); } } } +fn replace_ident(ident: &mut Ident) { + *ident = proj_ident(ident); +} + #[derive(Default)] struct Register(Option<(Ident, usize)>); diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index 8c482415..26465ac1 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -1,6 +1,9 @@ -use proc_macro2::Ident; use quote::format_ident; -use syn::Attribute; +use syn::{ + punctuated::Punctuated, + token::{self, Comma}, + Attribute, GenericParam, Generics, Ident, Lifetime, LifetimeDef, +}; /// Makes the ident of projected type from the reference of the original ident. pub(crate) fn proj_ident(ident: &Ident) -> Ident { @@ -11,6 +14,46 @@ pub(crate) fn proj_trait_ident(ident: &Ident) -> Ident { format_ident!("__{}ProjectionTrait", ident) } +/// Determine the lifetime names. Ensure it doesn't overlap with any existing lifetime names. +pub(crate) fn proj_lifetime_name( + lifetime_name: &mut String, + generics: &Punctuated, +) { + let existing_lifetimes: Vec = generics + .iter() + .filter_map(|param| { + if let GenericParam::Lifetime(LifetimeDef { lifetime, .. }) = param { + Some(lifetime.to_string()) + } else { + None + } + }) + .collect(); + while existing_lifetimes.iter().any(|name| name.starts_with(&**lifetime_name)) { + lifetime_name.push('_'); + } +} + +/// Makes the generics of projected type from the reference of the original generics. +pub(crate) fn proj_generics(generics: &mut Generics, lifetime: Lifetime) { + if let lt_token @ None = &mut generics.lt_token { + *lt_token = Some(token::Lt::default()) + } + if let gt_token @ None = &mut generics.gt_token { + *gt_token = Some(token::Gt::default()) + } + + generics.params.insert( + 0, + GenericParam::Lifetime(LifetimeDef { + attrs: Vec::new(), + lifetime, + colon_token: None, + bounds: Punctuated::new(), + }), + ); +} + pub(crate) trait VecExt { fn find_remove(&mut self, ident: &str) -> Option; } @@ -48,7 +91,7 @@ pub(crate) fn crate_path() -> Ident { macro_rules! error { ($span:expr, $msg:expr) => { - syn::Error::new_spanned($span, $msg) + syn::Error::new_spanned(&$span, $msg) }; ($span:expr, $($tt:tt)*) => { error!($span, format!($($tt)*)) diff --git a/tests/project.rs b/tests/project.rs index d61d5626..dcac1cd0 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -10,7 +10,7 @@ use pin_project::{pin_project, project}; #[project] // Nightly does not need a dummy attribute to the function. #[test] -fn test_project_attr() { +fn project_stmt_expr() { // struct #[pin_project] @@ -85,7 +85,7 @@ fn test_project_attr() { } #[test] -fn test_project_attr_nightly() { +fn project_stmt_expr_nightly() { // enum #[pin_project] @@ -136,3 +136,65 @@ fn test_project_attr_nightly() { Baz::None => {} }; } + +#[test] +fn project_impl() { + #[pin_project] + struct HasGenerics { + #[pin] + field1: T, + field2: U, + } + + #[project] + impl HasGenerics { + fn a(self) { + let Self { field1, field2 } = self; + + let _x: Pin<&mut T> = field1; + let _y: &mut U = field2; + } + } + + #[pin_project] + struct NoneGenerics { + #[pin] + field1: i32, + field2: u32, + } + + #[project] + impl NoneGenerics {} + + #[pin_project] + struct HasLifetimes<'a, T, U> { + #[pin] + field1: &'a mut T, + field2: U, + } + + #[project] + impl HasLifetimes<'_, T, U> {} + + #[pin_project] + struct HasOverlappingLifetimes<'_pin, T, U> { + #[pin] + field1: &'_pin mut T, + field2: U, + } + + #[project] + impl<'_pin, T, U> HasOverlappingLifetimes<'_pin, T, U> {} + + #[pin_project] + struct HasOverlappingLifetimes2 { + #[pin] + field1: T, + field2: U, + } + + #[project] + impl HasOverlappingLifetimes2 { + fn foo<'_pin>(&'_pin self) {} + } +}