Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tracing-attributes] Support for using #[instrument] with async-trait #711

Merged
merged 4 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,6 @@ async fn write(stream: &mut TcpStream) -> io::Result<usize> {
Under the hood, the `#[instrument]` macro performs same the explicit span
attachment that `Future::instrument` does.

Note: the [`#[tracing::instrument]`](https://github.com/tokio-rs/tracing/issues/399)` macro does not work correctly with the [async-trait](https://github.com/dtolnay/async-trait) crate. This bug is tracked in [#399](https://github.com/tokio-rs/tracing/issues/399).

## Getting Help

First, see if the answer to your question can be found in the API documentation.
Expand Down
2 changes: 2 additions & 0 deletions tracing-attributes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ async-await = []
[dependencies]
syn = { version = "1", features = ["full", "extra-traits"] }
quote = "1"
proc-macro2 = "1"


[dev-dependencies]
tracing = { path = "../tracing", version = "0.1" }
tracing-futures = { path = "../tracing-futures", version = "0.2" }
tokio-test = { version = "0.2.0" }
tracing-core = { path = "../tracing-core", version = "0.1"}
async-trait = "0.1"

[badges]
maintenance = { status = "experimental" }
233 changes: 209 additions & 24 deletions tracing-attributes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ use std::iter;
use proc_macro::TokenStream;
use quote::{quote, quote_spanned, ToTokens};
use syn::{
spanned::Spanned, AttributeArgs, FieldPat, FnArg, Ident, ItemFn, Lit, LitInt, Meta, MetaList,
MetaNameValue, NestedMeta, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct,
PatType, Path, Signature,
spanned::Spanned, AttributeArgs, Block, Expr, ExprCall, FieldPat, FnArg, Ident, Item, ItemFn,
Lit, LitInt, Meta, MetaList, MetaNameValue, NestedMeta, Pat, PatIdent, PatReference, PatStruct,
PatTuple, PatTupleStruct, PatType, Path, Signature, Stmt,
};

/// Instruments a function to create and enter a `tracing` [span] every time
Expand Down Expand Up @@ -168,6 +168,31 @@ use syn::{
/// # Ok(())
/// }
/// ```
///
/// It also works with [async-trait](https://crates.io/crates/async-trait)
/// (a crate that allows async functions on traits,
/// something not currently possible with rustc alone),
/// and hopefully most libraries that exhibit similar behaviors:
///
/// ```
/// # use tracing::instrument;
/// use async_trait::async_trait;
///
/// #[async_trait]
/// pub trait Foo {
/// async fn foo(&self, v: usize) -> ();
/// }
///
/// #[derive(Debug)]
/// struct FooImpl;
///
/// #[async_trait]
/// impl Foo for FooImpl {
/// #[instrument(skip(self))]
/// async fn foo(&self, v: usize) {}
/// }
/// ```

///
/// [span]: https://docs.rs/tracing/latest/tracing/span/index.html
/// [`tracing`]: https://github.com/tokio-rs/tracing
Expand All @@ -177,6 +202,44 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
let input: ItemFn = syn::parse_macro_input!(item as ItemFn);
let args = syn::parse_macro_input!(args as AttributeArgs);

// check for async_trait-like patterns in the block and wrap the
// internal function with Instrument instead of wrapping the
// async_trait generated wrapper
if let Some(internal_fun_name) =
get_async_trait_name(&input.block, input.sig.asyncness.is_some())
{
// let's rewrite some statements!
let mut stmts: Vec<Stmt> = input.block.stmts.to_vec();
for stmt in &mut stmts {
if let Stmt::Item(Item::Fn(fun)) = stmt {
// instrument the function if we considered it as the one we truly want to trace
if fun.sig.ident == internal_fun_name {
*stmt = syn::parse2(gen_body(fun, args, Some(input.sig.ident.to_string())))
.unwrap();
break;
}
}
}

let sig = &input.sig;
let attrs = &input.attrs;
quote!(
#(#attrs) *
#sig {
#(#stmts) *
}
)
.into()
} else {
gen_body(&input, args, None).into()
}
}

fn gen_body(
input: &ItemFn,
args: AttributeArgs,
fun_name: Option<String>,
) -> proc_macro2::TokenStream {
// these are needed ahead of time, as ItemFn contains the function body _and_
// isn't representable inside a quote!/quote_spanned! macro
// (Syn's ToTokens isn't implemented for ItemFn)
Expand Down Expand Up @@ -206,7 +269,11 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
} = sig;

// function name
let ident_str = ident.to_string();
let ident_str = if let Some(x) = &fun_name {
x.clone()
} else {
ident.to_string()
};

// generate this inside a closure, so we can return early on errors.
let span = (|| {
Expand All @@ -216,42 +283,67 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
Err(err) => return quote!(#err),
};

let param_names: Vec<Ident> = params
let param_names: Vec<(Ident, Ident)> = params
.clone()
.into_iter()
.flat_map(|param| match param {
FnArg::Typed(PatType { pat, .. }) => param_names(*pat),
FnArg::Receiver(_) => Box::new(iter::once(Ident::new("self", param.span()))),
})
// if we are inside a function generated by async-trait, we
// should take care to rewrite "_self" as "self" for
// 'user convenience'
.map(|x| {
if fun_name.is_some() && x == "_self" {
(Ident::new("self", x.span()), x)
} else {
(x.clone(), x)
}
})
.collect();

// TODO: allow the user to rename fields at will (all the
// machinery should be here)

// Little dance with new (user-exposed) names and old (internal)
// names of identifiers. That way, you can do the following
// even though async_trait rewrite "self" as "_self":
// ```
// #[async_trait]
// impl Foo for FooImpl {
// #[instrument(skip(self))]
// async fn foo(&self, v: usize) {}
// }
// ```

for skip in &skips {
if !param_names.contains(skip) {
if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
return quote_spanned! {skip.span()=>
compile_error!("attempting to skip non-existent parameter")
};
}
}

let param_names: Vec<Ident> = param_names
// filter out skipped fields
let param_names: Vec<(Ident, Ident)> = param_names
.into_iter()
.filter(|ident| !skips.contains(ident))
.filter(|(ident, _)| !skips.contains(ident))
.collect();

let fields = match fields(&args, &param_names) {
let new_param_names: Vec<&Ident> = param_names.iter().map(|x| &x.0).collect();

let fields = match fields(&args, &new_param_names) {
Ok(fields) => fields,
Err(err) => return quote!(#err),
};

let param_names_clone = param_names.clone();

let level = level(&args);
let target = target(&args);
let span_name = name(&args, ident_str);

let mut quoted_fields: Vec<_> = param_names
.into_iter()
.map(|i| quote!(#i = tracing::field::debug(&#i)))
.iter()
.map(|(user_name, real_name)| quote!(#user_name = tracing::field::debug(&#real_name)))
.collect();
quoted_fields.extend(fields.into_iter().map(|(key, value)| {
let value = match value {
Expand All @@ -276,7 +368,8 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
// If `err` is in args, instrument any resulting `Err`s.
let body = if asyncness.is_some() {
if instrument_err(&args) {
quote_spanned! {block.span()=>
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
tracing_futures::Instrument::instrument(async move {
match async move { #block }.await {
Ok(x) => Ok(x),
Expand All @@ -286,18 +379,20 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
}
}
}, __tracing_attr_span).await
}
)
} else {
quote_spanned! {block.span()=>
tracing_futures::Instrument::instrument(
async move { #block },
__tracing_attr_span
)
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
tracing_futures::Instrument::instrument(
async move { #block },
__tracing_attr_span
)
.await
}
)
}
} else if instrument_err(&args) {
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
let __tracing_attr_guard = __tracing_attr_span.enter();
match { #block } {
Ok(x) => Ok(x),
Expand All @@ -309,6 +404,7 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
)
} else {
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
let __tracing_attr_guard = __tracing_attr_span.enter();
#block
)
Expand All @@ -319,11 +415,9 @@ pub fn instrument(args: TokenStream, item: TokenStream) -> TokenStream {
#vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
#where_clause
{
let __tracing_attr_span = #span;
#body
}
)
.into()
}

fn param_names(pat: Pat) -> Box<dyn Iterator<Item = Ident>> {
Expand Down Expand Up @@ -462,7 +556,7 @@ fn target(args: &[NestedMeta]) -> impl ToTokens {

fn fields(
args: &[NestedMeta],
param_names: &[Ident],
param_names: &[&Ident],
) -> Result<(Vec<(Ident, Option<Lit>)>), impl ToTokens> {
let mut fields = args.iter().filter_map(|arg| match arg {
NestedMeta::Meta(Meta::List(MetaList {
Expand Down Expand Up @@ -579,3 +673,94 @@ fn instrument_err(args: &[NestedMeta]) -> bool {
_ => false,
})
}

// Get the name of the inner function we need to hook, if the function
// was generated by async-trait.
// When we are given a function generated by async-trait, that function
// is only a "temporary" one that returns a pinned future, and it is
// that pinned future that needs to be instrumented, otherwise we will
// only collect information on the moment the future was "built",
// and not its true span of execution.
// So we inspect the block of the function to find if we can find the
// pattern `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` and
// return the name `foo` if that is the case. Our caller will then be
// able to use that information to instrument the proper function.
// (this follows the approach suggested in
// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
fn get_async_trait_name(block: &Block, block_is_async: bool) -> Option<String> {
// are we in an async context? If yes, this isn't a async_trait-like pattern
if block_is_async {
return None;
}

// list of async functions declared inside the block
let mut inside_funs = Vec::new();
// last expression declared in the block (it determines the return
// value of the block, so that if we are working on a function
// whose `trait` or `impl` declaration is annotated by async_trait,
// this is quite likely the point where the future is pinned)
let mut last_expr = None;

// obtain the list of direct internal functions and the last
// expression of the block
for stmt in &block.stmts {
if let Stmt::Item(Item::Fn(fun)) = &stmt {
// is the function declared as async? If so, this is a good
// candidate, let's keep it in hand
if fun.sig.asyncness.is_some() {
inside_funs.push(fun.sig.ident.to_string());
}
} else if let Stmt::Expr(e) = &stmt {
last_expr = Some(e);
}
}

// let's play with (too much) pattern matching
// is the last expression a function call?
if let Some(Expr::Call(ExprCall {
func: outside_func,
args: outside_args,
..
})) = last_expr
{
if let Expr::Path(path) = outside_func.as_ref() {
// is it a call to `Box::pin()`?
if "Box::pin" == path_to_string(&path.path) {
// does it takes at least an argument? (if it doesn't,
// it's not gonna compile anyway, but that's no reason
// to (try to) perform an out of bounds access)
if outside_args.is_empty() {
return None;
}
// is the argument to Box::pin a function call itself?
if let Expr::Call(ExprCall { func, args, .. }) = &outside_args[0] {
if let Expr::Path(inside_path) = func.as_ref() {
// "stringify" the path of the function called
let func_name = path_to_string(&inside_path.path);
// is this function directly defined insided the current block?
if inside_funs.contains(&func_name) {
// we must hook this function now
return Some(func_name);
}
}
}
}
}
}
None
}

// Return a path as a String
fn path_to_string(path: &Path) -> String {
use std::fmt::Write;
// some heuristic to prevent too many allocations
let mut res = String::with_capacity(path.segments.len() * 5);
for i in 0..path.segments.len() {
write!(&mut res, "{}", path.segments[i].ident)
.expect("writing to a String should never fail");
if i < path.segments.len() - 1 {
res.push_str("::");
}
}
res
}
Loading