From 3b02fe5aeb406741aa3521d908345c04f3bd08b5 Mon Sep 17 00:00:00 2001 From: Alex Kirszenberg Date: Fri, 14 Jul 2023 16:32:19 +0200 Subject: [PATCH] AdjacencyMap::reverse_topological --- .../src/derive/trace_raw_vcs_macro.rs | 13 ++++-- .../src/derive/value_debug_format_macro.rs | 17 ++++++-- crates/turbo-tasks/src/graph/adjacency_map.rs | 41 ++++++++++++++----- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs b/crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs index 813366f540b03e..ae77e53e7560df 100644 --- a/crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs +++ b/crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs @@ -11,14 +11,19 @@ fn filter_field(field: &Field) -> bool { } pub fn derive_trace_raw_vcs(input: TokenStream) -> TokenStream { - let derive_input = parse_macro_input!(input as DeriveInput); + let mut derive_input = parse_macro_input!(input as DeriveInput); let ident = &derive_input.ident; - let generics = &derive_input.generics; + + for type_param in derive_input.generics.type_params_mut() { + type_param + .bounds + .push(syn::parse_quote!(turbo_tasks::trace::TraceRawVcs)); + } + let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl(); let trace_items = match_expansion(&derive_input, &trace_named, &trace_unnamed, &trace_unit); - let generics_params = &generics.params.iter().collect::>(); quote! { - impl #generics turbo_tasks::trace::TraceRawVcs for #ident #generics #(where #generics_params: turbo_tasks::trace::TraceRawVcs)* { + impl #impl_generics turbo_tasks::trace::TraceRawVcs for #ident #ty_generics #where_clause { fn trace_raw_vcs(&self, __context__: &mut turbo_tasks::trace::TraceRawVcsContext) { #trace_items } diff --git a/crates/turbo-tasks-macros/src/derive/value_debug_format_macro.rs b/crates/turbo-tasks-macros/src/derive/value_debug_format_macro.rs index 29cf401506477c..291fe1f910504e 100644 --- a/crates/turbo-tasks-macros/src/derive/value_debug_format_macro.rs +++ b/crates/turbo-tasks-macros/src/derive/value_debug_format_macro.rs @@ -16,16 +16,27 @@ fn filter_field(field: &Field) -> bool { /// Fields annotated with `#[debug_ignore]` will not appear in the /// `ValueDebugFormat` representation of the type. pub fn derive_value_debug_format(input: TokenStream) -> TokenStream { - let derive_input = parse_macro_input!(input as DeriveInput); + let mut derive_input = parse_macro_input!(input as DeriveInput); let ident = &derive_input.ident; + + for type_param in derive_input.generics.type_params_mut() { + type_param + .bounds + .push(syn::parse_quote!(turbo_tasks::debug::ValueDebugFormat)); + type_param.bounds.push(syn::parse_quote!(std::fmt::Debug)); + type_param.bounds.push(syn::parse_quote!(std::marker::Send)); + type_param.bounds.push(syn::parse_quote!(std::marker::Sync)); + } + let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl(); + let formatting_logic = match_expansion(&derive_input, &format_named, &format_unnamed, &format_unit); let value_debug_format_ident = get_value_debug_format_ident(ident); quote! { - impl #ident { + impl #impl_generics #ident #ty_generics #where_clause { #[doc(hidden)] #[allow(non_snake_case)] async fn #value_debug_format_ident(&self, depth: usize) -> anyhow::Result { @@ -39,7 +50,7 @@ pub fn derive_value_debug_format(input: TokenStream) -> TokenStream { } } - impl turbo_tasks::debug::ValueDebugFormat for #ident { + impl #impl_generics turbo_tasks::debug::ValueDebugFormat for #ident #ty_generics #where_clause { fn value_debug_format<'a>(&'a self, depth: usize) -> turbo_tasks::debug::ValueDebugFormatString<'a> { turbo_tasks::debug::ValueDebugFormatString::Async( Box::pin(async move { diff --git a/crates/turbo-tasks/src/graph/adjacency_map.rs b/crates/turbo-tasks/src/graph/adjacency_map.rs index cd85614d18f5f7..b5aa8683bd6235 100644 --- a/crates/turbo-tasks/src/graph/adjacency_map.rs +++ b/crates/turbo-tasks/src/graph/adjacency_map.rs @@ -1,8 +1,13 @@ use std::collections::{HashMap, HashSet}; +use serde::{Deserialize, Serialize}; +use turbo_tasks_macros::{TraceRawVcs, ValueDebugFormat}; + use super::graph_store::{GraphNode, GraphStore}; +use crate as turbo_tasks; /// A graph traversal that builds an adjacency map +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TraceRawVcs, ValueDebugFormat)] pub struct AdjacencyMap where T: Eq + std::hash::Hash + Clone, @@ -68,10 +73,10 @@ impl AdjacencyMap where T: Eq + std::hash::Hash + Clone, { - /// Returns an iterator over the nodes in reverse topological order, + /// Returns an owned iterator over the nodes in reverse topological order, /// starting from the roots. - pub fn into_reverse_topological(self) -> ReverseTopologicalIter { - ReverseTopologicalIter { + pub fn into_reverse_topological(self) -> IntoReverseTopologicalIter { + IntoReverseTopologicalIter { adjacency_map: self.adjacency_map, stack: self .roots @@ -82,13 +87,27 @@ where } } + /// Returns an iterator over the nodes in reverse topological order, + /// starting from the roots. + pub fn reverse_topological<'graph>(&'graph self) -> ReverseTopologicalIter<'graph, T> { + ReverseTopologicalIter { + adjacency_map: &self.adjacency_map, + stack: self + .roots + .iter() + .map(|root| (ReverseTopologicalPass::Pre, root)) + .collect(), + visited: HashSet::new(), + } + } + /// Returns an iterator over the nodes in reverse topological order, /// starting from the given node. pub fn reverse_topological_from_node<'graph>( &'graph self, node: &'graph T, - ) -> ReverseTopologicalFromNodeIter<'graph, T> { - ReverseTopologicalFromNodeIter { + ) -> ReverseTopologicalIter<'graph, T> { + ReverseTopologicalIter { adjacency_map: &self.adjacency_map, stack: vec![(ReverseTopologicalPass::Pre, node)], visited: HashSet::new(), @@ -104,7 +123,7 @@ enum ReverseTopologicalPass { /// An iterator over the nodes of a graph in reverse topological order, starting /// from the roots. -pub struct ReverseTopologicalIter +pub struct IntoReverseTopologicalIter where T: Eq + std::hash::Hash + Clone, { @@ -113,7 +132,7 @@ where visited: HashSet, } -impl Iterator for ReverseTopologicalIter +impl Iterator for IntoReverseTopologicalIter where T: Eq + std::hash::Hash + Clone, { @@ -153,8 +172,8 @@ where } /// An iterator over the nodes of a graph in reverse topological order, starting -/// from a given node. -pub struct ReverseTopologicalFromNodeIter<'graph, T> +/// from the roots. +pub struct ReverseTopologicalIter<'graph, T> where T: Eq + std::hash::Hash + Clone, { @@ -163,7 +182,7 @@ where visited: HashSet<&'graph T>, } -impl<'graph, T> Iterator for ReverseTopologicalFromNodeIter<'graph, T> +impl<'graph, T> Iterator for ReverseTopologicalIter<'graph, T> where T: Eq + std::hash::Hash + Clone, { @@ -178,7 +197,7 @@ where break current; } ReverseTopologicalPass::Pre => { - if self.visited.contains(¤t) { + if self.visited.contains(current) { continue; }