From 2e4a4f367d68b25e936c1b87207f7872bd35bbc2 Mon Sep 17 00:00:00 2001 From: Simon THOBY Date: Wed, 10 Mar 2021 21:11:15 +0100 Subject: [PATCH] attributes: update `#[instrument]` to support `async-trait` 0.1.43+ #1228) It works with both the old and new version of async-trait (except for one doc test that failed previously, and that works with the new version). One nice thing is that the code is simpler (e.g.g no self renaming to _self, which will enable some simplifications in the future). A minor nitpick is that I disliked the deeply nested pattern matching in get_async_trait_kind (previously: get_async_trait_function), so I "flattened" that a bit. Fixes #1219. --- tracing-attributes/Cargo.toml | 5 +- tracing-attributes/src/lib.rs | 402 ++++++++++++++++----------- tracing-attributes/tests/async_fn.rs | 9 +- 3 files changed, 250 insertions(+), 166 deletions(-) diff --git a/tracing-attributes/Cargo.toml b/tracing-attributes/Cargo.toml index d351783e6d..3105ee0935 100644 --- a/tracing-attributes/Cargo.toml +++ b/tracing-attributes/Cargo.toml @@ -34,15 +34,14 @@ proc-macro = true [dependencies] proc-macro2 = "1" -syn = { version = "1", default-features = false, features = ["full", "parsing", "printing", "visit-mut", "clone-impls", "extra-traits", "proc-macro"] } +syn = { version = "1", default-features = false, features = ["full", "parsing", "printing", "visit", "visit-mut", "clone-impls", "extra-traits", "proc-macro"] } quote = "1" - [dev-dependencies] tracing = { path = "../tracing", version = "0.2" } tokio-test = { version = "0.2.0" } tracing-core = { path = "../tracing-core", version = "0.2"} -async-trait = "0.1" +async-trait = "0.1.44" [badges] maintenance = { status = "experimental" } diff --git a/tracing-attributes/src/lib.rs b/tracing-attributes/src/lib.rs index 73e93fd524..54f5efebe5 100644 --- a/tracing-attributes/src/lib.rs +++ b/tracing-attributes/src/lib.rs @@ -74,7 +74,6 @@ patterns_in_fns_without_body, private_in_public, unconditional_recursion, - unused, unused_allocation, unused_comparisons, unused_parens, @@ -89,9 +88,9 @@ use quote::{quote, quote_spanned, ToTokens}; use syn::ext::IdentExt as _; use syn::parse::{Parse, ParseStream}; use syn::{ - punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprCall, FieldPat, FnArg, Ident, Item, - ItemFn, LitInt, LitStr, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, - PatType, Path, Signature, Stmt, Token, + punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, + Ident, Item, ItemFn, LitInt, LitStr, Pat, PatIdent, PatReference, PatStruct, PatTuple, + PatTupleStruct, PatType, Path, Signature, Stmt, Token, TypePath, }; /// Instruments a function to create and enter a `tracing` [span] every time /// the function is called. @@ -221,11 +220,12 @@ use syn::{ /// } /// ``` /// -/// An interesting note on this subject is that references to the `Self` -/// type inside the `fields` argument are only allowed when the instrumented -/// function is a method aka. the function receives `self` as an argument. -/// For example, this *will not work* because it doesn't receive `self`: -/// ```compile_fail +/// Note than on `async-trait` <= 0.1.43, references to the `Self` +/// type inside the `fields` argument were only allowed when the instrumented +/// function is a method (i.e., the function receives `self` as an argument). +/// For example, this *used to not work* because the instrument function +/// didn't receive `self`: +/// ``` /// # use tracing::instrument; /// use async_trait::async_trait; /// @@ -244,7 +244,8 @@ use syn::{ /// } /// ``` /// Instead, you should manually rewrite any `Self` types as the type for -/// which you implement the trait: `#[instrument(fields(tmp = std::any::type_name::()))]`. +/// which you implement the trait: `#[instrument(fields(tmp = std::any::type_name::()))]` +/// (or maybe you can just bump `async-trait`). /// /// [span]: https://docs.rs/tracing/latest/tracing/span/index.html /// [`tracing`]: https://github.com/tokio-rs/tracing @@ -254,30 +255,47 @@ pub fn instrument( args: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let input: ItemFn = syn::parse_macro_input!(item as ItemFn); + let input = syn::parse_macro_input!(item as ItemFn); let args = syn::parse_macro_input!(args as InstrumentArgs); let instrumented_function_name = input.sig.ident.to_string(); - // check for async_trait-like patterns in the block and wrap the - // internal function with Instrument instead of wrapping the - // async_trait generated wrapper + // check for async_trait-like patterns in the block, and instrument + // the future instead of the wrapper if let Some(internal_fun) = get_async_trait_info(&input.block, input.sig.asyncness.is_some()) { // let's rewrite some statements! - let mut stmts: Vec = input.block.stmts.to_vec(); - for stmt in &mut stmts { - if let Stmt::Item(Item::Fn(fun)) = stmt { - // instrument the function if we considered it as the one we truly want to trace - if fun.sig.ident == internal_fun.name { - *stmt = syn::parse2(gen_body( - fun, - args, - instrumented_function_name, - Some(internal_fun), - )) - .unwrap(); - break; + let mut out_stmts = Vec::with_capacity(input.block.stmts.len()); + for stmt in &input.block.stmts { + if stmt == internal_fun.source_stmt { + match internal_fun.kind { + // async-trait <= 0.1.43 + AsyncTraitKind::Function(fun) => { + out_stmts.push(gen_function( + fun, + args, + instrumented_function_name, + internal_fun.self_type, + )); + } + // async-trait >= 0.1.44 + AsyncTraitKind::Async(async_expr) => { + // fallback if we couldn't find the '__async_trait' binding, might be + // useful for crates exhibiting the same behaviors as async-trait + let instrumented_block = gen_block( + &async_expr.block, + &input.sig.inputs, + true, + args, + instrumented_function_name, + None, + ); + let async_attrs = &async_expr.attrs; + out_stmts.push(quote! { + Box::pin(#(#async_attrs) * async move { #instrumented_block }) + }); + } } + break; } } @@ -287,20 +305,21 @@ pub fn instrument( quote!( #(#attrs) * #vis #sig { - #(#stmts) * + #(#out_stmts) * } ) .into() } else { - gen_body(&input, args, instrumented_function_name, None).into() + gen_function(&input, args, instrumented_function_name, None).into() } } -fn gen_body( +/// Given an existing function, generate an instrumented version of that function +fn gen_function( input: &ItemFn, - mut args: InstrumentArgs, + args: InstrumentArgs, instrumented_function_name: String, - async_trait_fun: Option, + self_type: Option, ) -> proc_macro2::TokenStream { // these are needed ahead of time, as ItemFn contains the function body _and_ // isn't representable inside a quote!/quote_spanned! macro @@ -330,9 +349,39 @@ fn gen_body( .. } = sig; - let err = args.err; let warnings = args.warnings(); + let body = gen_block( + block, + params, + asyncness.is_some(), + args, + instrumented_function_name, + self_type, + ); + + quote!( + #(#attrs) * + #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type + #where_clause + { + #warnings + #body + } + ) +} + +/// Instrument a block +fn gen_block( + block: &Block, + params: &Punctuated, + async_context: bool, + mut args: InstrumentArgs, + instrumented_function_name: String, + self_type: Option, +) -> proc_macro2::TokenStream { + let err = args.err; + // generate the span's name let span_name = args // did the user override the span's name? @@ -353,8 +402,8 @@ fn gen_body( FnArg::Receiver(_) => Box::new(iter::once(Ident::new("self", param.span()))), }) // Little dance with new (user-exposed) names and old (internal) - // names of identifiers. That way, you can do the following - // even though async_trait rewrite "self" as "_self": + // names of identifiers. That way, we could do the following + // even though async_trait (<=0.1.43) rewrites "self" as "_self": // ``` // #[async_trait] // impl Foo for FooImpl { @@ -363,10 +412,9 @@ fn gen_body( // } // ``` .map(|x| { - // if we are inside a function generated by async-trait, we - // should take care to rewrite "_self" as "self" for - // 'user convenience' - if async_trait_fun.is_some() && x == "_self" { + // if we are inside a function generated by async-trait <=0.1.43, we need to + // take care to rewrite "_self" as "self" for 'user convenience' + if self_type.is_some() && x == "_self" { (Ident::new("self", x.span()), x) } else { (x.clone(), x) @@ -387,7 +435,7 @@ fn gen_body( // filter out skipped fields let quoted_fields: Vec<_> = param_names - .into_iter() + .iter() .filter(|(param, _)| { if args.skips.contains(param) { return false; @@ -407,13 +455,19 @@ fn gen_body( .map(|(user_name, real_name)| quote!(#user_name = tracing::field::debug(&#real_name))) .collect(); - // when async-trait is in use, replace instances of "self" with "_self" inside the fields values - if let (Some(ref async_trait_fun), Some(Fields(ref mut fields))) = - (async_trait_fun, &mut args.fields) - { - let mut replacer = SelfReplacer { - ty: async_trait_fun.self_type.clone(), + // replace every use of a variable with its original name + if let Some(Fields(ref mut fields)) = args.fields { + let mut replacer = IdentAndTypesRenamer { + idents: param_names, + types: Vec::new(), }; + + // when async-trait <=0.1.43 is in use, replace instances + // of the "Self" type inside the fields values + if let Some(self_type) = self_type { + replacer.types.push(("Self", self_type)); + } + for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) { syn::visit_mut::visit_expr_mut(&mut replacer, e); } @@ -436,9 +490,9 @@ fn gen_body( // which is `instrument`ed using `tracing-futures`. Otherwise, this will // enter the span and then perform the rest of the body. // If `err` is in args, instrument any resulting `Err`s. - let body = if asyncness.is_some() { + if async_context { if err { - quote_spanned! {block.span()=> + quote_spanned!(block.span()=> let __tracing_attr_span = #span; tracing::Instrument::instrument(async move { match async move { #block }.await { @@ -450,7 +504,7 @@ fn gen_body( } } }, __tracing_attr_span).await - } + ) } else { quote_spanned!(block.span()=> let __tracing_attr_span = #span; @@ -481,17 +535,7 @@ fn gen_body( let __tracing_attr_guard = __tracing_attr_span.enter(); #block ) - }; - - quote!( - #(#attrs) * - #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type - #where_clause - { - #warnings - #body - } - ) + } } #[derive(Default, Debug)] @@ -835,6 +879,20 @@ mod kw { syn::custom_keyword!(err); } +enum AsyncTraitKind<'a> { + // old construction. Contains the function + Function(&'a ItemFn), + // new construction. Contains a reference to the async block + Async(&'a ExprAsync), +} + +struct AsyncTraitInfo<'a> { + // statement that must be patched + source_stmt: &'a Stmt, + kind: AsyncTraitKind<'a>, + self_type: Option, +} + // Get the AST of the inner function we need to hook, if it was generated // by async-trait. // When we are given a function annotated by async-trait, that function @@ -842,118 +900,122 @@ mod kw { // user logic, and it is that pinned future that needs to be instrumented. // Were we to instrument its parent, we would only collect information // regarding the allocation of that future, and not its own span of execution. -// So we inspect the block of the function to find if it matches the pattern -// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` and we return -// the name `foo` if that is the case. 'gen_body' will then be able -// to use that information to instrument the proper function. +// Depending on the version of async-trait, we inspect the block of the function +// to find if it matches the pattern +// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` (<=0.1.43), or if +// it matches `Box::pin(async move { ... }) (>=0.1.44). We the return the +// statement that must be instrumented, along with some other informations. +// 'gen_body' will then be able to use that information to instrument the +// proper function/future. // (this follows the approach suggested in // https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673) -fn get_async_trait_function(block: &Block, block_is_async: bool) -> Option<&ItemFn> { +fn get_async_trait_info(block: &Block, block_is_async: bool) -> Option> { // are we in an async context? If yes, this isn't a async_trait-like pattern if block_is_async { return None; } // list of async functions declared inside the block - let mut inside_funs = Vec::new(); - // last expression declared in the block (it determines the return - // value of the block, so that if we are working on a function - // whose `trait` or `impl` declaration is annotated by async_trait, - // this is quite likely the point where the future is pinned) - let mut last_expr = None; - - // obtain the list of direct internal functions and the last - // expression of the block - for stmt in &block.stmts { + let inside_funs = block.stmts.iter().filter_map(|stmt| { if let Stmt::Item(Item::Fn(fun)) = &stmt { - // is the function declared as async? If so, this is a good - // candidate, let's keep it in hand + // If the function is async, this is a candidate if fun.sig.asyncness.is_some() { - inside_funs.push(fun); + return Some((stmt, fun)); } - } else if let Stmt::Expr(e) = &stmt { - last_expr = Some(e); } - } + None + }); - // let's play with (too much) pattern matching - // is the last expression a function call? - if let Some(Expr::Call(ExprCall { - func: outside_func, - args: outside_args, - .. - })) = last_expr - { - if let Expr::Path(path) = outside_func.as_ref() { - // is it a call to `Box::pin()`? - if "Box::pin" == path_to_string(&path.path) { - // does it takes at least an argument? (if it doesn't, - // it's not gonna compile anyway, but that's no reason - // to (try to) perform an out of bounds access) - if outside_args.is_empty() { - return None; - } - // is the argument to Box::pin a function call itself? - if let Expr::Call(ExprCall { func, .. }) = &outside_args[0] { - if let Expr::Path(inside_path) = func.as_ref() { - // "stringify" the path of the function called - let func_name = path_to_string(&inside_path.path); - // is this function directly defined insided the current block? - for fun in inside_funs { - if fun.sig.ident == func_name { - // we must hook this function now - return Some(fun); - } - } - } - } - } + // last expression of the block (it determines the return value + // of the block, so that if we are working on a function whose + // `trait` or `impl` declaration is annotated by async_trait, + // this is quite likely the point where the future is pinned) + let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| { + if let Stmt::Expr(expr) = stmt { + Some((stmt, expr)) + } else { + None } + })?; + + // is the last expression a function call? + let (outside_func, outside_args) = match last_expr { + Expr::Call(ExprCall { func, args, .. }) => (func, args), + _ => return None, + }; + + // is it a call to `Box::pin()`? + let path = match outside_func.as_ref() { + Expr::Path(path) => &path.path, + _ => return None, + }; + if !path_to_string(path).ends_with("Box::pin") { + return None; } - None -} -struct AsyncTraitInfo { - name: String, - self_type: Option, -} + // Does the call take an argument? If it doesn't, + // it's not gonna compile anyway, but that's no reason + // to (try to) perform an out of bounds access + if outside_args.is_empty() { + return None; + } -// Return the informations necessary to process a function annotated with async-trait. -fn get_async_trait_info(block: &Block, block_is_async: bool) -> Option { - let fun = get_async_trait_function(block, block_is_async)?; + // Is the argument to Box::pin an async block that + // captures its arguments? + if let Expr::Async(async_expr) = &outside_args[0] { + // check that the move 'keyword' is present + async_expr.capture?; - // if "_self" is present as an argument, we store its type to be able to rewrite "Self" (the + return Some(AsyncTraitInfo { + source_stmt: last_expr_stmt, + kind: AsyncTraitKind::Async(async_expr), + self_type: None, + }); + } + + // Is the argument to Box::pin a function call itself? + let func = match &outside_args[0] { + Expr::Call(ExprCall { func, .. }) => func, + _ => return None, + }; + + // "stringify" the path of the function called + let func_name = match **func { + Expr::Path(ref func_path) => path_to_string(&func_path.path), + _ => return None, + }; + + // Was that function defined inside of the current block? + // If so, retrieve the statement where it was declared and the function itself + let (stmt_func_declaration, func) = inside_funs + .into_iter() + .find(|(_, fun)| fun.sig.ident == func_name)?; + + // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the // parameter type) with the type of "_self" - let self_type = fun - .sig - .inputs - .iter() - .map(|arg| { - if let FnArg::Typed(ty) = arg { - if let Pat::Ident(PatIdent { ident, .. }) = &*ty.pat { - if ident == "_self" { - let mut ty = &*ty.ty; - // extract the inner type if the argument is "&self" or "&mut self" - if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty { - ty = &*elem; - } - if let syn::Type::Path(tp) = ty { - return Some(tp.clone()); - } + let mut self_type = None; + for arg in &func.sig.inputs { + if let FnArg::Typed(ty) = arg { + if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat { + if ident == "_self" { + let mut ty = *ty.ty.clone(); + // extract the inner type if the argument is "&self" or "&mut self" + if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty { + ty = *elem; + } + + if let syn::Type::Path(tp) = ty { + self_type = Some(tp); + break; } } } - - None - }) - .next(); - let self_type = match self_type { - Some(x) => x, - None => None, - }; + } + } Some(AsyncTraitInfo { - name: fun.sig.ident.to_string(), + source_stmt: stmt_func_declaration, + kind: AsyncTraitKind::Function(func), self_type, }) } @@ -973,26 +1035,48 @@ fn path_to_string(path: &Path) -> String { res } -// A visitor struct replacing the "self" and "Self" tokens in user-supplied fields expressions when -// the function is generated by async-trait. -struct SelfReplacer { - ty: Option, +/// A visitor struct to replace idents and types in some piece +/// of code (e.g. the "self" and "Self" tokens in user-supplied +/// fields expressions when the function is generated by an old +/// version of async-trait). +struct IdentAndTypesRenamer<'a> { + types: Vec<(&'a str, TypePath)>, + idents: Vec<(Ident, Ident)>, } -impl syn::visit_mut::VisitMut for SelfReplacer { +impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> { + // we deliberately compare strings because we want to ignore the spans + // If we apply clippy's lint, the behavior changes + #[allow(clippy::cmp_owned)] fn visit_ident_mut(&mut self, id: &mut Ident) { - if id == "self" { - *id = Ident::new("_self", id.span()) + for (old_ident, new_ident) in &self.idents { + if id.to_string() == old_ident.to_string() { + *id = new_ident.clone(); + } } } fn visit_type_mut(&mut self, ty: &mut syn::Type) { - if let syn::Type::Path(syn::TypePath { ref mut path, .. }) = ty { - if path_to_string(path) == "Self" { - if let Some(ref true_type) = self.ty { - *path = true_type.path.clone(); + for (type_name, new_type) in &self.types { + if let syn::Type::Path(TypePath { path, .. }) = ty { + if path_to_string(path) == *type_name { + *ty = syn::Type::Path(new_type.clone()); } } } } } + +// A visitor struct that replace an async block by its patched version +struct AsyncTraitBlockReplacer<'a> { + block: &'a Block, + patched_block: Block, +} + +impl<'a> syn::visit_mut::VisitMut for AsyncTraitBlockReplacer<'a> { + fn visit_block_mut(&mut self, i: &mut Block) { + if i == self.block { + *i = self.patched_block.clone(); + } + } +} diff --git a/tracing-attributes/tests/async_fn.rs b/tracing-attributes/tests/async_fn.rs index 71bafbf994..dcd502f31b 100644 --- a/tracing-attributes/tests/async_fn.rs +++ b/tracing-attributes/tests/async_fn.rs @@ -172,18 +172,19 @@ fn async_fn_with_async_trait_and_fields_expressions() { #[async_trait] impl Test for TestImpl { // check that self is correctly handled, even when using async_trait - #[instrument(fields(val=self.foo(), test=%v+5))] - async fn call(&mut self, v: usize) {} + #[instrument(fields(val=self.foo(), val2=Self::clone(self).foo(), test=%_v+5))] + async fn call(&mut self, _v: usize) {} } let span = span::mock().named("call"); let (collector, handle) = collector::mock() .new_span( span.clone().with_field( - field::mock("v") + field::mock("_v") .with_value(&tracing::field::debug(5)) .and(field::mock("test").with_value(&tracing::field::debug(10))) - .and(field::mock("val").with_value(&42u64)), + .and(field::mock("val").with_value(&42u64)) + .and(field::mock("val2").with_value(&42u64)), ), ) .enter(span.clone())