diff --git a/crates/turbo-tasks-macros/src/func.rs b/crates/turbo-tasks-macros/src/func.rs index 6afda74a7bad8..a02eb175ae06b 100644 --- a/crates/turbo-tasks-macros/src/func.rs +++ b/crates/turbo-tasks-macros/src/func.rs @@ -1,7 +1,11 @@ use proc_macro2::Ident; use syn::{ - parse_quote, punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprPath, FnArg, Pat, - PatIdent, PatType, Path, Receiver, ReturnType, Signature, Token, Type, + parse_quote, + punctuated::{Pair, Punctuated}, + spanned::Spanned, + AngleBracketedGenericArguments, Block, Expr, ExprPath, FnArg, GenericArgument, Pat, PatIdent, + PatType, Path, PathArguments, PathSegment, Receiver, ReturnType, Signature, Token, Type, + TypeGroup, TypePath, TypeTuple, }; #[derive(Debug)] @@ -257,10 +261,11 @@ impl TurboFn { .collect(); let ident = &self.ident; - let output = &self.output; + let orig_output = &self.output; + let new_output = expand_vc_return_type(orig_output); parse_quote! { - fn #ident(#exposed_inputs) -> <#output as turbo_tasks::task::TaskOutput>::Return + fn #ident(#exposed_inputs) -> #new_output } } @@ -327,6 +332,100 @@ fn return_type_to_type(return_type: &ReturnType) -> Type { } } +fn expand_vc_return_type(orig_output: &Type) -> Type { + // HACK: Approximate the expansion that we'd otherwise get from + // `::Return`, so that the return type shown in the rustdocs + // is as simple as possible. Break out as soon as we see something we don't + // recognize. + let mut new_output = orig_output.clone(); + let mut found_vc = false; + loop { + new_output = match new_output { + Type::Group(TypeGroup { elem, .. }) => *elem, + Type::Tuple(TypeTuple { elems, .. }) if elems.is_empty() => { + Type::Path(parse_quote!(::turbo_tasks::Vc<()>)) + } + Type::Path(TypePath { + qself: None, + path: + Path { + leading_colon, + ref segments, + }, + }) => { + let mut pairs = segments.pairs(); + let mut cur_pair = pairs.next(); + + enum PathPrefix { + Anyhow, + TurboTasks, + } + + // try to strip a `turbo_tasks::` or `anyhow::` prefix + let prefix = if let Some(first) = cur_pair.as_ref().map(|p| p.value()) { + if first.arguments.is_none() { + if first.ident == "turbo_tasks" { + Some(PathPrefix::TurboTasks) + } else if first.ident == "anyhow" { + Some(PathPrefix::Anyhow) + } else { + None + } + } else { + None + } + } else { + None + }; + + if prefix.is_some() { + cur_pair = pairs.next(); // strip the matched prefix + } else if leading_colon.is_some() { + break; // something like `::Vc` isn't valid + } + + // Look for a `Vc<...>` or `Result<...>` generic + let Some(Pair::End(PathSegment { + ident, + arguments: + PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }), + })) = cur_pair + else { + break; + }; + if ident == "Vc" { + found_vc = true; + break; // Vc is the bottom-most level + } + if ident == "Result" && args.len() == 1 { + let GenericArgument::Type(ty) = + args.first().expect("Result<...> type has an argument") + else { + break; + }; + ty.clone() + } else { + break; // we only support expanding Result<...> + } + } + _ => break, + } + } + + if !found_vc { + orig_output + .span() + .unwrap() + .error( + "Expected return type to be `turbo_tasks::Vc` or `anyhow::Result>`. \ + Unable to process type.", + ) + .emit(); + } + + new_output +} + /// The context in which the function is being defined. #[derive(Debug, Clone, Eq, PartialEq)] pub enum DefinitionContext {