diff --git a/tracing-attributes/src/lib.rs b/tracing-attributes/src/lib.rs index 54f5efebe5..639adad445 100644 --- a/tracing-attributes/src/lib.rs +++ b/tracing-attributes/src/lib.rs @@ -264,39 +264,45 @@ pub fn instrument( // the future instead of the wrapper if let Some(internal_fun) = get_async_trait_info(&input.block, input.sig.asyncness.is_some()) { // let's rewrite some statements! - let mut out_stmts = Vec::with_capacity(input.block.stmts.len()); - for stmt in &input.block.stmts { - if stmt == internal_fun.source_stmt { - match internal_fun.kind { - // async-trait <= 0.1.43 - AsyncTraitKind::Function(fun) => { - out_stmts.push(gen_function( - fun, - args, - instrumented_function_name, - internal_fun.self_type, - )); - } - // async-trait >= 0.1.44 - AsyncTraitKind::Async(async_expr) => { - // fallback if we couldn't find the '__async_trait' binding, might be - // useful for crates exhibiting the same behaviors as async-trait - let instrumented_block = gen_block( - &async_expr.block, - &input.sig.inputs, - true, - args, - instrumented_function_name, - None, - ); - let async_attrs = &async_expr.attrs; - out_stmts.push(quote! { - Box::pin(#(#async_attrs) * async move { #instrumented_block }) - }); + let mut out_stmts: Vec = input + .block + .stmts + .iter() + .map(|stmt| stmt.to_token_stream()) + .collect(); + + if let Some((iter, _stmt)) = input + .block + .stmts + .iter() + .enumerate() + .find(|(_iter, stmt)| *stmt == internal_fun.source_stmt) + { + // instrument the future by rewriting the corresponding statement + out_stmts[iter] = match internal_fun.kind { + // async-trait <= 0.1.43 + AsyncTraitKind::Function(fun) => gen_function( + fun, + args, + instrumented_function_name.as_str(), + internal_fun.self_type.as_ref(), + ), + // async-trait >= 0.1.44 + AsyncTraitKind::Async(async_expr) => { + let instrumented_block = gen_block( + &async_expr.block, + &input.sig.inputs, + true, + args, + instrumented_function_name.as_str(), + None, + ); + let async_attrs = &async_expr.attrs; + quote! { + Box::pin(#(#async_attrs) * async move { #instrumented_block }) } } - break; - } + }; } let vis = &input.vis; @@ -310,7 +316,7 @@ pub fn instrument( ) .into() } else { - gen_function(&input, args, instrumented_function_name, None).into() + gen_function(&input, args, instrumented_function_name.as_str(), None).into() } } @@ -318,8 +324,8 @@ pub fn instrument( fn gen_function( input: &ItemFn, args: InstrumentArgs, - instrumented_function_name: String, - self_type: Option, + instrumented_function_name: &str, + self_type: Option<&syn::TypePath>, ) -> proc_macro2::TokenStream { // these are needed ahead of time, as ItemFn contains the function body _and_ // isn't representable inside a quote!/quote_spanned! macro @@ -377,8 +383,8 @@ fn gen_block( params: &Punctuated, async_context: bool, mut args: InstrumentArgs, - instrumented_function_name: String, - self_type: Option, + instrumented_function_name: &str, + self_type: Option<&syn::TypePath>, ) -> proc_macro2::TokenStream { let err = args.err; @@ -465,7 +471,7 @@ fn gen_block( // when async-trait <=0.1.43 is in use, replace instances // of the "Self" type inside the fields values if let Some(self_type) = self_type { - replacer.types.push(("Self", self_type)); + replacer.types.push(("Self", self_type.clone())); } for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) { diff --git a/tracing-attributes/tests/async_fn.rs b/tracing-attributes/tests/async_fn.rs index dcd502f31b..54e945967c 100644 --- a/tracing-attributes/tests/async_fn.rs +++ b/tracing-attributes/tests/async_fn.rs @@ -4,6 +4,7 @@ mod support; use support::*; +use std::{future::Future, pin::Pin, sync::Arc}; use tracing::collect::with_default; use tracing_attributes::instrument; @@ -214,6 +215,18 @@ fn async_fn_with_async_trait_and_fields_expressions_with_generic_parameter() { #[derive(Clone, Debug)] struct TestImpl; + // we also test sync functions that return futures, as they should be handled just like + // async-trait (>= 0.1.44) functions + impl TestImpl { + #[instrument(fields(Self=std::any::type_name::()))] + fn sync_fun(&self) -> Pin + Send + '_>> { + let val = self.clone(); + Box::pin(async move { + let _ = val; + }) + } + } + #[async_trait] impl Test for TestImpl { // instrumenting this is currently not possible, see https://github.com/tokio-rs/tracing/issues/864#issuecomment-667508801 @@ -221,7 +234,9 @@ fn async_fn_with_async_trait_and_fields_expressions_with_generic_parameter() { async fn call() {} #[instrument(fields(Self=std::any::type_name::()))] - async fn call_with_self(&self) {} + async fn call_with_self(&self) { + self.sync_fun().await; + } #[instrument(fields(Self=std::any::type_name::()))] async fn call_with_mut_self(&mut self) {} @@ -230,6 +245,7 @@ fn async_fn_with_async_trait_and_fields_expressions_with_generic_parameter() { //let span = span::mock().named("call"); let span2 = span::mock().named("call_with_self"); let span3 = span::mock().named("call_with_mut_self"); + let span4 = span::mock().named("sync_fun"); let (collector, handle) = collector::mock() /*.new_span(span.clone() .with_field( @@ -243,6 +259,13 @@ fn async_fn_with_async_trait_and_fields_expressions_with_generic_parameter() { .with_field(field::mock("Self").with_value(&std::any::type_name::())), ) .enter(span2.clone()) + .new_span( + span4 + .clone() + .with_field(field::mock("Self").with_value(&std::any::type_name::())), + ) + .enter(span4.clone()) + .exit(span4) .exit(span2.clone()) .drop_span(span2) .new_span( @@ -266,3 +289,45 @@ fn async_fn_with_async_trait_and_fields_expressions_with_generic_parameter() { handle.assert_finished(); } + +#[test] +fn out_of_scope_fields() { + // Reproduces tokio-rs/tracing#1296 + + struct Thing { + metrics: Arc<()>, + } + + impl Thing { + #[instrument(skip(self, _req), fields(app_id))] + fn call(&mut self, _req: ()) -> Pin> + Send + Sync>> { + // ... + let metrics = self.metrics.clone(); + // ... + Box::pin(async move { + // ... + metrics // cannot find value `metrics` in this scope + }) + } + } + + let span = span::mock().named("call"); + let (collector, handle) = collector::mock() + .new_span(span.clone()) + .enter(span.clone()) + .exit(span.clone()) + .drop_span(span) + .done() + .run_with_handle(); + + with_default(collector, || { + block_on_future(async { + let mut my_thing = Thing { + metrics: Arc::new(()), + }; + my_thing.call(()).await; + }); + }); + + handle.assert_finished(); +}