Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AdjacencyMap::reverse_topological (+ fixes) #5527

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/turbo-tasks-macros/src/derive/trace_raw_vcs_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ fn filter_field(field: &Field) -> bool {
}

pub fn derive_trace_raw_vcs(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
let mut derive_input = parse_macro_input!(input as DeriveInput);
let ident = &derive_input.ident;
let generics = &derive_input.generics;

for type_param in derive_input.generics.type_params_mut() {
type_param
.bounds
.push(syn::parse_quote!(turbo_tasks::trace::TraceRawVcs));
}
let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();

let trace_items = match_expansion(&derive_input, &trace_named, &trace_unnamed, &trace_unit);
let generics_params = &generics.params.iter().collect::<Vec<_>>();
quote! {
impl #generics turbo_tasks::trace::TraceRawVcs for #ident #generics #(where #generics_params: turbo_tasks::trace::TraceRawVcs)* {
impl #impl_generics turbo_tasks::trace::TraceRawVcs for #ident #ty_generics #where_clause {
fn trace_raw_vcs(&self, __context__: &mut turbo_tasks::trace::TraceRawVcsContext) {
#trace_items
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,28 @@ fn filter_field(field: &Field) -> bool {
/// Fields annotated with `#[debug_ignore]` will not appear in the
/// `ValueDebugFormat` representation of the type.
pub fn derive_value_debug_format(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
let mut derive_input = parse_macro_input!(input as DeriveInput);

let ident = &derive_input.ident;

for type_param in derive_input.generics.type_params_mut() {
type_param
.bounds
.push(syn::parse_quote!(turbo_tasks::debug::ValueDebugFormat));
type_param.bounds.push(syn::parse_quote!(std::fmt::Debug));
type_param.bounds.push(syn::parse_quote!(std::marker::Send));
type_param.bounds.push(syn::parse_quote!(std::marker::Sync));
}
let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();

let formatting_logic =
match_expansion(&derive_input, &format_named, &format_unnamed, &format_unit);

let value_debug_format_ident = get_value_debug_format_ident(ident);

quote! {
#[doc(hidden)]
impl #ident {
impl #impl_generics #ident #ty_generics #where_clause {
#[doc(hidden)]
#[allow(non_snake_case)]
async fn #value_debug_format_ident(&self, depth: usize) -> anyhow::Result<turbo_tasks::Vc<turbo_tasks::debug::ValueDebugString>> {
Expand All @@ -40,7 +51,7 @@ pub fn derive_value_debug_format(input: TokenStream) -> TokenStream {
}
}

impl turbo_tasks::debug::ValueDebugFormat for #ident {
impl #impl_generics turbo_tasks::debug::ValueDebugFormat for #ident #ty_generics #where_clause {
fn value_debug_format<'a>(&'a self, depth: usize) -> turbo_tasks::debug::ValueDebugFormatString<'a> {
turbo_tasks::debug::ValueDebugFormatString::Async(
Box::pin(async move {
Expand Down
41 changes: 30 additions & 11 deletions crates/turbo-tasks/src/graph/adjacency_map.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};
use turbo_tasks_macros::{TraceRawVcs, ValueDebugFormat};

use super::graph_store::{GraphNode, GraphStore};
use crate as turbo_tasks;

/// A graph traversal that builds an adjacency map
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TraceRawVcs, ValueDebugFormat)]
pub struct AdjacencyMap<T>
where
T: Eq + std::hash::Hash + Clone,
Expand Down Expand Up @@ -68,10 +73,10 @@ impl<T> AdjacencyMap<T>
where
T: Eq + std::hash::Hash + Clone,
{
/// Returns an iterator over the nodes in reverse topological order,
/// Returns an owned iterator over the nodes in reverse topological order,
/// starting from the roots.
pub fn into_reverse_topological(self) -> ReverseTopologicalIter<T> {
ReverseTopologicalIter {
pub fn into_reverse_topological(self) -> IntoReverseTopologicalIter<T> {
IntoReverseTopologicalIter {
adjacency_map: self.adjacency_map,
stack: self
.roots
Expand All @@ -82,13 +87,27 @@ where
}
}

/// Returns an iterator over the nodes in reverse topological order,
/// starting from the roots.
pub fn reverse_topological(&self) -> ReverseTopologicalIter<T> {
ReverseTopologicalIter {
adjacency_map: &self.adjacency_map,
stack: self
.roots
.iter()
.map(|root| (ReverseTopologicalPass::Pre, root))
.collect(),
visited: HashSet::new(),
}
}

/// Returns an iterator over the nodes in reverse topological order,
/// starting from the given node.
pub fn reverse_topological_from_node<'graph>(
&'graph self,
node: &'graph T,
) -> ReverseTopologicalFromNodeIter<'graph, T> {
ReverseTopologicalFromNodeIter {
) -> ReverseTopologicalIter<'graph, T> {
ReverseTopologicalIter {
adjacency_map: &self.adjacency_map,
stack: vec![(ReverseTopologicalPass::Pre, node)],
visited: HashSet::new(),
Expand All @@ -104,7 +123,7 @@ enum ReverseTopologicalPass {

/// An iterator over the nodes of a graph in reverse topological order, starting
/// from the roots.
pub struct ReverseTopologicalIter<T>
pub struct IntoReverseTopologicalIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -113,7 +132,7 @@ where
visited: HashSet<T>,
}

impl<T> Iterator for ReverseTopologicalIter<T>
impl<T> Iterator for IntoReverseTopologicalIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand Down Expand Up @@ -153,8 +172,8 @@ where
}

/// An iterator over the nodes of a graph in reverse topological order, starting
/// from a given node.
pub struct ReverseTopologicalFromNodeIter<'graph, T>
/// from the roots.
pub struct ReverseTopologicalIter<'graph, T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -163,7 +182,7 @@ where
visited: HashSet<&'graph T>,
}

impl<'graph, T> Iterator for ReverseTopologicalFromNodeIter<'graph, T>
impl<'graph, T> Iterator for ReverseTopologicalIter<'graph, T>
where
T: Eq + std::hash::Hash + Clone,
{
Expand All @@ -178,7 +197,7 @@ where
break current;
}
ReverseTopologicalPass::Pre => {
if self.visited.contains(&current) {
if self.visited.contains(current) {
continue;
}

Expand Down