diff --git a/crates/turbo-static/.gitignore b/crates/turbo-static/.gitignore new file mode 100644 index 0000000000000..32d96908cdc6b --- /dev/null +++ b/crates/turbo-static/.gitignore @@ -0,0 +1,2 @@ +call_resolver.bincode +graph.cypherl diff --git a/crates/turbo-static/Cargo.toml b/crates/turbo-static/Cargo.toml new file mode 100644 index 0000000000000..699c00f64f88c --- /dev/null +++ b/crates/turbo-static/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "turbo-static" +version = "0.1.0" +edition = "2021" +license = "MPL-2.0" + +[dependencies] +bincode = "1.3.3" +clap = { workspace = true, features = ["derive"] } +ctrlc = "3.4.4" +ignore = "0.4.22" +itertools.workspace = true +lsp-server = "0.7.6" +lsp-types = "0.95.1" +proc-macro2 = { workspace = true, features = ["span-locations"] } +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +serde_path_to_error = "0.1.16" +syn = { version = "2", features = ["parsing", "full", "visit", "extra-traits"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +tracing.workspace = true +walkdir = "2.5.0" + +[lints] +workspace = true diff --git a/crates/turbo-static/readme.md b/crates/turbo-static/readme.md new file mode 100644 index 0000000000000..4ad86f1490410 --- /dev/null +++ b/crates/turbo-static/readme.md @@ -0,0 +1,33 @@ +# Turbo Static + +Leverages rust-analyzer to build a complete view into the static dependency +graph for your turbo tasks project. + +## How it works + +- find all occurences of #[turbo_tasks::function] across all the packages you + want to query +- for each of the tasks we find, query rust analyzer to see which tasks call + them +- apply some very basis control flow analysis to determine whether the call is + made 1 time, 0/1 times, or 0+ times, corresponding to direct calls, + conditionals, or for loops +- produce a cypher file that can be loaded into a graph database to query the + static dependency graph + +## Usage + +This uses an in memory persisted database to cache rust-analyzer queries. +To reset the cache, pass the `--reindex` flag. Running will produce a +`graph.cypherl` file which can be loaded into any cypher-compatible database. + +```bash +# pass in the root folders you want to analyze. the system will recursively +# parse all rust code looking for turbo tasks functions +cargo run --release -- ../../../turbo ../../../next.js +# now you can load graph.cypherl into your database of choice, such as neo4j +docker run \ + --publish=7474:7474 --publish=7687:7687 \ + --volume=$HOME/neo4j/data:/data \ + neo4j +``` diff --git a/crates/turbo-static/src/call_resolver.rs b/crates/turbo-static/src/call_resolver.rs new file mode 100644 index 0000000000000..aec97607a9438 --- /dev/null +++ b/crates/turbo-static/src/call_resolver.rs @@ -0,0 +1,165 @@ +use std::{collections::HashMap, fs::OpenOptions, path::PathBuf}; + +use crate::{lsp_client::RAClient, Identifier, IdentifierReference}; + +/// A wrapper around a rust-analyzer client that can resolve call references. +/// This is quite expensive so we cache the results in an on-disk key-value +/// store. +pub struct CallResolver<'a> { + client: &'a mut RAClient, + state: HashMap>, + path: Option, +} + +/// On drop, serialize the state to disk +impl<'a> Drop for CallResolver<'a> { + fn drop(&mut self) { + let file = OpenOptions::new() + .create(true) + .truncate(false) + .write(true) + .open(self.path.as_ref().unwrap()) + .unwrap(); + bincode::serialize_into(file, &self.state).unwrap(); + } +} + +impl<'a> CallResolver<'a> { + pub fn new(client: &'a mut RAClient, path: Option) -> Self { + // load bincode-encoded HashMap from path + let state = path + .as_ref() + .and_then(|path| { + let file = OpenOptions::new() + .create(true) + .truncate(false) + .read(true) + .write(true) + .open(path) + .unwrap(); + let reader = std::io::BufReader::new(file); + bincode::deserialize_from::<_, HashMap>>( + reader, + ) + .map_err(|e| { + tracing::warn!("failed to load existing cache, restarting"); + e + }) + .ok() + }) + .unwrap_or_default(); + Self { + client, + state, + path, + } + } + + pub fn cached_count(&self) -> usize { + self.state.len() + } + + pub fn cleared(mut self) -> Self { + // delete file if exists and clear state + self.state = Default::default(); + if let Some(path) = self.path.as_ref() { + std::fs::remove_file(path).unwrap(); + } + self + } + + pub fn resolve(&mut self, ident: &Identifier) -> Vec { + if let Some(data) = self.state.get(ident) { + tracing::info!("skipping {}", ident); + return data.to_owned(); + }; + + tracing::info!("checking {}", ident); + + let mut count = 0; + let _response = loop { + let Some(response) = self.client.request(lsp_server::Request { + id: 1.into(), + method: "textDocument/prepareCallHierarchy".to_string(), + params: serde_json::to_value(&lsp_types::CallHierarchyPrepareParams { + text_document_position_params: lsp_types::TextDocumentPositionParams { + position: ident.range.start, + text_document: lsp_types::TextDocumentIdentifier { + uri: lsp_types::Url::from_file_path(&ident.path).unwrap(), + }, + }, + work_done_progress_params: lsp_types::WorkDoneProgressParams { + work_done_token: Some(lsp_types::ProgressToken::String( + "prepare".to_string(), + )), + }, + }) + .unwrap(), + }) else { + tracing::warn!("RA server shut down"); + return vec![]; + }; + + if let Some(Some(value)) = response.result.as_ref().map(|r| r.as_array()) { + if !value.is_empty() { + break value.to_owned(); + } + count += 1; + } + + // textDocument/prepareCallHierarchy will sometimes return an empty array so try + // at most 5 times + if count > 5 { + tracing::warn!("discovered isolated task {}", ident); + break vec![]; + } + + std::thread::sleep(std::time::Duration::from_secs(1)); + }; + + // callHierarchy/incomingCalls + let Some(response) = self.client.request(lsp_server::Request { + id: 1.into(), + method: "callHierarchy/incomingCalls".to_string(), + params: serde_json::to_value(lsp_types::CallHierarchyIncomingCallsParams { + partial_result_params: lsp_types::PartialResultParams::default(), + item: lsp_types::CallHierarchyItem { + name: ident.name.to_owned(), + kind: lsp_types::SymbolKind::FUNCTION, + data: None, + tags: None, + detail: None, + uri: lsp_types::Url::from_file_path(&ident.path).unwrap(), + range: ident.range, + selection_range: ident.range, + }, + work_done_progress_params: lsp_types::WorkDoneProgressParams { + work_done_token: Some(lsp_types::ProgressToken::String("prepare".to_string())), + }, + }) + .unwrap(), + }) else { + tracing::warn!("RA server shut down"); + return vec![]; + }; + + let links = if let Some(e) = response.error { + tracing::warn!("unable to resolve {}: {:?}", ident, e); + vec![] + } else { + let response: Result, _> = + serde_path_to_error::deserialize(response.result.unwrap()); + + response + .unwrap() + .into_iter() + .map(|i| i.into()) + .collect::>() + }; + + tracing::debug!("links: {:?}", links); + + self.state.insert(ident.to_owned(), links.clone()); + links + } +} diff --git a/crates/turbo-static/src/identifier.rs b/crates/turbo-static/src/identifier.rs new file mode 100644 index 0000000000000..c92a3da7bb5d2 --- /dev/null +++ b/crates/turbo-static/src/identifier.rs @@ -0,0 +1,95 @@ +use std::{fs, path::PathBuf}; + +use lsp_types::{CallHierarchyIncomingCall, CallHierarchyItem, Range}; + +/// A task that references another, with the range of the reference +#[derive(Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize, Clone, Debug)] +pub struct IdentifierReference { + pub identifier: Identifier, + pub references: Vec, // the places where this identifier is used +} + +/// identifies a task by its file, and range in the file +#[derive(Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize, Clone)] +pub struct Identifier { + pub path: String, + // technically you can derive this from the name and range but it's easier to just store it + pub name: String, + // post_transform_name: Option, + pub range: lsp_types::Range, +} + +impl Identifier { + /// check the span matches and the text matches + /// + /// `same_location` is used to check if the location of the identifier is + /// the same as the other + pub fn equals_ident(&self, other: &syn::Ident, match_location: bool) -> bool { + *other == self.name + && (!match_location + || (self.range.start.line == other.span().start().line as u32 + && self.range.start.character == other.span().start().column as u32)) + } + + /// We cannot use `item.name` here in all cases as, during testing, the name + /// does not always align with the exact text in the range. + fn get_name(item: &CallHierarchyItem) -> String { + // open file, find range inside, extract text + let file = fs::read_to_string(item.uri.path()).unwrap(); + let start = item.selection_range.start; + let end = item.selection_range.end; + file.lines() + .nth(start.line as usize) + .unwrap() + .chars() + .skip(start.character as usize) + .take(end.character as usize - start.character as usize) + .collect() + } +} + +impl From<(PathBuf, syn::Ident)> for Identifier { + fn from((path, ident): (PathBuf, syn::Ident)) -> Self { + Self { + path: path.display().to_string(), + name: ident.to_string(), + // post_transform_name: None, + range: Range { + start: lsp_types::Position { + line: ident.span().start().line as u32 - 1, + character: ident.span().start().column as u32, + }, + end: lsp_types::Position { + line: ident.span().end().line as u32 - 1, + character: ident.span().end().column as u32, + }, + }, + } + } +} + +impl From for IdentifierReference { + fn from(item: CallHierarchyIncomingCall) -> Self { + Self { + identifier: Identifier { + name: Identifier::get_name(&item.from), + // post_transform_name: Some(item.from.name), + path: item.from.uri.path().to_owned(), + range: item.from.selection_range, + }, + references: item.from_ranges, + } + } +} + +impl std::fmt::Debug for Identifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + +impl std::fmt::Display for Identifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}#{}", self.path, self.range.start.line, self.name,) + } +} diff --git a/crates/turbo-static/src/lsp_client.rs b/crates/turbo-static/src/lsp_client.rs new file mode 100644 index 0000000000000..25d29a7efd26d --- /dev/null +++ b/crates/turbo-static/src/lsp_client.rs @@ -0,0 +1,161 @@ +use std::{path::PathBuf, process, sync::mpsc}; + +use lsp_server::Message; + +/// An LSP client for Rust Analyzer (RA) that launches it as a subprocess. +pub struct RAClient { + /// Handle to the client + handle: process::Child, + sender: Option>, + receiver: Option>, +} + +impl RAClient { + /// Create a new LSP client for Rust Analyzer. + pub fn new() -> Self { + let stdin = process::Stdio::piped(); + let stdout = process::Stdio::piped(); + let stderr = process::Stdio::inherit(); + + let child = process::Command::new("rust-analyzer") + .stdin(stdin) + .stdout(stdout) + .stderr(stderr) + // .env("RA_LOG", "info") + .env("RUST_BACKTRACE", "1") + .spawn() + .expect("Failed to start RA LSP server"); + Self { + handle: child, + sender: None, + receiver: None, + } + } + + pub fn start(&mut self, folders: &[PathBuf]) { + let stdout = self.handle.stdout.take().unwrap(); + let mut stdin = self.handle.stdin.take().unwrap(); + + let (writer_sender, writer_receiver) = mpsc::sync_channel::(0); + _ = std::thread::spawn(move || { + writer_receiver + .into_iter() + .try_for_each(|it| it.write(&mut stdin)) + }); + + let (reader_sender, reader_receiver) = mpsc::sync_channel::(0); + _ = std::thread::spawn(move || { + let mut reader = std::io::BufReader::new(stdout); + while let Ok(Some(msg)) = Message::read(&mut reader) { + reader_sender + .send(msg) + .expect("receiver was dropped, failed to send a message"); + } + }); + + self.sender = Some(writer_sender); + self.receiver = Some(reader_receiver); + + let workspace_paths = folders + .iter() + .map(|p| std::fs::canonicalize(p).unwrap()) + .map(|p| lsp_types::WorkspaceFolder { + name: p.file_name().unwrap().to_string_lossy().to_string(), + uri: lsp_types::Url::from_file_path(p).unwrap(), + }) + .collect::>(); + + _ = self.request(lsp_server::Request { + id: 1.into(), + method: "initialize".to_string(), + params: serde_json::to_value(lsp_types::InitializeParams { + workspace_folders: Some(workspace_paths), + process_id: Some(std::process::id()), + capabilities: lsp_types::ClientCapabilities { + workspace: Some(lsp_types::WorkspaceClientCapabilities { + workspace_folders: Some(true), + ..Default::default() + }), + ..Default::default() + }, + work_done_progress_params: lsp_types::WorkDoneProgressParams { + work_done_token: Some(lsp_types::ProgressToken::String("prepare".to_string())), + }, + // we use workspace_folders so root_path and root_uri can be + // empty + ..Default::default() + }) + .unwrap(), + }); + + self.notify(lsp_server::Notification { + method: "initialized".to_string(), + params: serde_json::to_value(lsp_types::InitializedParams {}).unwrap(), + }); + } + + /// Send an LSP request to the server. This returns an option + /// in the case of an error such as the server being shut down + /// from pressing `Ctrl+C`. + pub fn request(&mut self, message: lsp_server::Request) -> Option { + tracing::debug!("sending {:?}", message); + self.sender + .as_mut() + .unwrap() + .send(Message::Request(message)) + .ok()?; + + loop { + match self.receiver.as_mut().unwrap().recv() { + Ok(lsp_server::Message::Response(response)) => { + tracing::debug!("received {:?}", response); + return Some(response); + } + Ok(m) => tracing::trace!("unexpected message: {:?}", m), + Err(_) => { + tracing::trace!("error receiving message"); + return None; + } + } + } + } + + pub fn notify(&mut self, message: lsp_server::Notification) { + self.sender + .as_mut() + .unwrap() + .send(Message::Notification(message)) + .expect("failed to send message"); + } +} + +impl Drop for RAClient { + fn drop(&mut self) { + if self.sender.is_some() { + let Some(resp) = self.request(lsp_server::Request { + id: 1.into(), + method: "shutdown".to_string(), + params: serde_json::to_value(()).unwrap(), + }) else { + return; + }; + + if resp.error.is_none() { + tracing::info!("shutting down RA LSP server"); + self.notify(lsp_server::Notification { + method: "exit".to_string(), + params: serde_json::to_value(()).unwrap(), + }); + self.handle + .wait() + .expect("failed to wait for RA LSP server"); + tracing::info!("shut down RA LSP server"); + } else { + tracing::error!("failed to shutdown RA LSP server: {:#?}", resp); + } + } + + self.sender = None; + self.receiver = None; + } +} diff --git a/crates/turbo-static/src/main.rs b/crates/turbo-static/src/main.rs new file mode 100644 index 0000000000000..a3eece260d62c --- /dev/null +++ b/crates/turbo-static/src/main.rs @@ -0,0 +1,303 @@ +#![feature(entry_insert)] + +use std::{ + collections::{HashMap, HashSet}, + error::Error, + fs, + path::PathBuf, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use call_resolver::CallResolver; +use clap::Parser; +use identifier::{Identifier, IdentifierReference}; +use itertools::Itertools; +use syn::visit::Visit; +use visitor::CallingStyleVisitor; + +use crate::visitor::CallingStyle; + +mod call_resolver; +mod identifier; +mod lsp_client; +mod visitor; + +#[derive(Parser)] +struct Opt { + #[clap(required = true)] + paths: Vec, + + /// reparse all files + #[clap(long)] + reparse: bool, + + /// reindex all files + #[clap(long)] + reindex: bool, +} + +fn main() -> Result<(), Box> { + tracing_subscriber::fmt::init(); + let opt = Opt::parse(); + + let mut connection = lsp_client::RAClient::new(); + connection.start(&opt.paths); + + let call_resolver = CallResolver::new(&mut connection, Some("call_resolver.bincode".into())); + let mut call_resolver = if opt.reindex { + call_resolver.cleared() + } else { + call_resolver + }; + + let halt = Arc::new(AtomicBool::new(false)); + let halt_clone = halt.clone(); + ctrlc::set_handler({ + move || { + halt_clone.store(true, Ordering::SeqCst); + } + })?; + + tracing::info!("getting tasks"); + let mut tasks = get_all_tasks(&opt.paths); + let dep_tree = resolve_tasks(&mut tasks, &mut call_resolver, halt.clone()); + let concurrency = resolve_concurrency(&tasks, &dep_tree, halt.clone()); + + write_dep_tree(&tasks, concurrency, std::path::Path::new("graph.cypherl")); + + if halt.load(Ordering::Relaxed) { + tracing::info!("ctrl-c detected, exiting"); + } + + Ok(()) +} + +/// search the given folders recursively and attempt to find all tasks inside +#[tracing::instrument(skip_all)] +fn get_all_tasks(folders: &[PathBuf]) -> HashMap> { + let mut out = HashMap::new(); + + for folder in folders { + let walker = ignore::Walk::new(folder); + for entry in walker { + let entry = entry.unwrap(); + let rs_file = if let Some(true) = entry.file_type().map(|t| t.is_file()) { + let path = entry.path(); + let ext = path.extension().unwrap_or_default(); + if ext == "rs" { + std::fs::canonicalize(path).unwrap() + } else { + continue; + } + } else { + continue; + }; + + let file = fs::read_to_string(&rs_file).unwrap(); + let lines = file.lines(); + let mut occurences = vec![]; + + tracing::debug!("processing {}", rs_file.display()); + + for ((_, line), (line_no, _)) in lines.enumerate().tuple_windows() { + if line.contains("turbo_tasks::function") { + tracing::debug!("found at {:?}:L{}", rs_file, line_no); + occurences.push(line_no + 1); + } + } + + if occurences.is_empty() { + continue; + } + + // parse the file using syn and get the span of the functions + let file = syn::parse_file(&file).unwrap(); + let occurences_count = occurences.len(); + let mut visitor = visitor::TaskVisitor::new(); + syn::visit::visit_file(&mut visitor, &file); + if visitor.results.len() != occurences_count { + tracing::warn!( + "file {:?} passed the heuristic with {:?} but the visitor found {:?}", + rs_file, + occurences_count, + visitor.results.len() + ); + } + + out.extend( + visitor + .results + .into_iter() + .map(move |(ident, tags)| ((rs_file.clone(), ident).into(), tags)), + ) + } + } + + out +} + +/// Given a list of tasks, get all the tasks that call that one +fn resolve_tasks( + tasks: &mut HashMap>, + client: &mut CallResolver, + halt: Arc, +) -> HashMap> { + tracing::info!( + "found {} tasks, of which {} cached", + tasks.len(), + client.cached_count() + ); + + let mut unresolved = tasks.keys().cloned().collect::>(); + let mut resolved = HashMap::new(); + + while let Some(top) = unresolved.iter().next().cloned() { + unresolved.remove(&top); + + let callers = client.resolve(&top); + + // add all non-task callers to the unresolved list if they are not in the + // resolved list + for caller in callers.iter() { + if !resolved.contains_key(&caller.identifier) + && !unresolved.contains(&caller.identifier) + { + tracing::debug!("adding {} to unresolved", caller.identifier); + unresolved.insert(caller.identifier.to_owned()); + } + } + resolved.insert(top.to_owned(), callers); + + if halt.load(Ordering::Relaxed) { + break; + } + } + + resolved +} + +/// given a map of tasks and functions that call it, produce a map of tasks and +/// those tasks that it calls +/// +/// returns a list of pairs with a task, the task that calls it, and the calling +/// style +fn resolve_concurrency( + task_list: &HashMap>, + dep_tree: &HashMap>, // pairs of tasks and call trees + halt: Arc, +) -> Vec<(Identifier, Identifier, CallingStyle)> { + // println!("{:?}", dep_tree); + // println!("{:#?}", task_list); + + let mut edges = vec![]; + + for (ident, references) in dep_tree { + for reference in references { + if !dep_tree.contains_key(&reference.identifier) { + // this is a task that is not in the task list + // so we can't resolve it + tracing::error!("missing task for {}: {}", ident, reference.identifier); + for task in task_list.keys() { + if task.name == reference.identifier.name { + // we found a task that is not in the task list + // so we can't resolve it + tracing::trace!("- found {}", task); + continue; + } + } + continue; + } else { + // load the source file and get the calling style + let target = IdentifierReference { + identifier: ident.clone(), + references: reference.references.clone(), + }; + let mut visitor = CallingStyleVisitor::new(target); + tracing::info!("looking for {} from {}", ident, reference.identifier); + let file = + syn::parse_file(&fs::read_to_string(&reference.identifier.path).unwrap()) + .unwrap(); + visitor.visit_file(&file); + + edges.push(( + ident.clone(), + reference.identifier.clone(), + visitor.result().unwrap_or(CallingStyle::Once), + )); + } + + if halt.load(Ordering::Relaxed) { + break; + } + } + } + + // parse each fn between parent and child and get the max calling style + + edges +} + +/// Write the dep tree into the given file using cypher syntax +fn write_dep_tree( + task_list: &HashMap>, + dep_tree: Vec<(Identifier, Identifier, CallingStyle)>, + out: &std::path::Path, +) { + use std::io::Write; + + let mut node_ids = HashMap::new(); + let mut counter = 0; + + let mut file = std::fs::File::create(out).unwrap(); + + let empty = vec![]; + + // collect all tasks as well as all intermediate nodes + // tasks come last to ensure the tags are preserved + let node_list = dep_tree + .iter() + .flat_map(|(dest, src, _)| [(src, &empty), (dest, &empty)]) + .chain(task_list) + .collect::>(); + + for (ident, tags) in node_list { + counter += 1; + + let label = if !task_list.contains_key(ident) { + "Function" + } else if tags.contains(&"fs".to_string()) || tags.contains(&"network".to_string()) { + "ImpureTask" + } else { + "Task" + }; + + _ = writeln!( + file, + "CREATE (n_{}:{} {{name: '{}', file: '{}', line: {}, tags: [{}]}})", + counter, + label, + ident.name, + ident.path, + ident.range.start.line, + tags.iter().map(|t| format!("\"{}\"", t)).join(",") + ); + node_ids.insert(ident, counter); + } + + for (dest, src, style) in &dep_tree { + let style = match style { + CallingStyle::Once => "ONCE", + CallingStyle::ZeroOrOnce => "ZERO_OR_ONCE", + CallingStyle::ZeroOrMore => "ZERO_OR_MORE", + CallingStyle::OneOrMore => "ONE_OR_MORE", + }; + + let src_id = *node_ids.get(src).unwrap(); + let dst_id = *node_ids.get(dest).unwrap(); + + _ = writeln!(file, "CREATE (n_{})-[:{}]->(n_{})", src_id, style, dst_id,); + } +} diff --git a/crates/turbo-static/src/visitor.rs b/crates/turbo-static/src/visitor.rs new file mode 100644 index 0000000000000..113c1a4e1218f --- /dev/null +++ b/crates/turbo-static/src/visitor.rs @@ -0,0 +1,275 @@ +//! A visitor that traverses the AST and collects all functions or methods that +//! are annotated with `#[turbo_tasks::function]`. + +use std::{collections::VecDeque, ops::Add}; + +use lsp_types::Range; +use syn::{visit::Visit, Expr, Meta}; + +use crate::identifier::Identifier; + +pub struct TaskVisitor { + /// the list of results as pairs of an identifier and its tags + pub results: Vec<(syn::Ident, Vec)>, +} + +impl TaskVisitor { + pub fn new() -> Self { + Self { + results: Default::default(), + } + } +} + +impl Visit<'_> for TaskVisitor { + #[tracing::instrument(skip_all)] + fn visit_item_fn(&mut self, i: &syn::ItemFn) { + if let Some(tags) = extract_tags(i.attrs.iter()) { + tracing::trace!("L{}: {}", i.sig.ident.span().start().line, i.sig.ident,); + self.results.push((i.sig.ident.clone(), tags)); + } + } + + #[tracing::instrument(skip_all)] + fn visit_impl_item_fn(&mut self, i: &syn::ImplItemFn) { + if let Some(tags) = extract_tags(i.attrs.iter()) { + tracing::trace!("L{}: {}", i.sig.ident.span().start().line, i.sig.ident,); + self.results.push((i.sig.ident.clone(), tags)); + } + } +} + +fn extract_tags<'a>(mut meta: impl Iterator) -> Option> { + meta.find_map(|a| match &a.meta { + // path has two segments, turbo_tasks and function + Meta::Path(path) if path.segments.len() == 2 => { + let first = &path.segments[0]; + let second = &path.segments[1]; + (first.ident == "turbo_tasks" && second.ident == "function").then(std::vec::Vec::new) + } + Meta::List(list) if list.path.segments.len() == 2 => { + let first = &list.path.segments[0]; + let second = &list.path.segments[1]; + if first.ident != "turbo_tasks" || second.ident != "function" { + return None; + } + + // collect ident tokens as args + let tags: Vec<_> = list + .tokens + .clone() + .into_iter() + .filter_map(|t| { + if let proc_macro2::TokenTree::Ident(ident) = t { + Some(ident.to_string()) + } else { + None + } + }) + .collect(); + + Some(tags) + } + _ => { + tracing::trace!("skipping unknown annotation"); + None + } + }) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum CallingStyle { + Once = 0b0010, + ZeroOrOnce = 0b0011, + ZeroOrMore = 0b0111, + OneOrMore = 0b0110, +} + +impl CallingStyle { + fn bitset(self) -> u8 { + self as u8 + } +} + +impl Add for CallingStyle { + type Output = Self; + + /// Add two calling styles together to determine the calling style of the + /// target function within the source function. + /// + /// Consider it as a bitset over properties. + /// - 0b000: Nothing + /// - 0b001: Zero + /// - 0b010: Once + /// - 0b011: Zero Or Once + /// - 0b100: More Than Once + /// - 0b101: Zero Or More Than Once (?) + /// - 0b110: Once Or More + /// - 0b111: Zero Or More + /// + /// Note that zero is not a valid calling style. + fn add(self, rhs: Self) -> Self { + let left = self.bitset(); + let right = rhs.bitset(); + + // we treat this as a bitset under addition + #[allow(clippy::suspicious_arithmetic_impl)] + match left | right { + 0b0010 => CallingStyle::Once, + 0b011 => CallingStyle::ZeroOrOnce, + 0b0111 => CallingStyle::ZeroOrMore, + 0b0110 => CallingStyle::OneOrMore, + // the remaining 4 (null, zero, more than once, zero or more than once) + // are unreachable because we don't detect 'zero' or 'more than once' + _ => unreachable!(), + } + } +} + +pub struct CallingStyleVisitor { + pub reference: crate::IdentifierReference, + state: VecDeque, + halt: bool, +} + +impl CallingStyleVisitor { + /// Create a new visitor that will traverse the AST and determine the + /// calling style of the target function within the source function. + pub fn new(reference: crate::IdentifierReference) -> Self { + Self { + reference, + state: Default::default(), + halt: false, + } + } + + pub fn result(self) -> Option { + self.state + .into_iter() + .map(|b| match b { + CallingStyleVisitorState::Block => CallingStyle::Once, + CallingStyleVisitorState::Loop => CallingStyle::ZeroOrMore, + CallingStyleVisitorState::If => CallingStyle::ZeroOrOnce, + CallingStyleVisitorState::Closure => CallingStyle::ZeroOrMore, + }) + .reduce(|a, b| a + b) + } +} + +#[derive(Debug, Clone, Copy)] +enum CallingStyleVisitorState { + Block, + Loop, + If, + Closure, +} + +impl Visit<'_> for CallingStyleVisitor { + fn visit_item_fn(&mut self, i: &'_ syn::ItemFn) { + self.state.push_back(CallingStyleVisitorState::Block); + syn::visit::visit_item_fn(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_impl_item_fn(&mut self, i: &'_ syn::ImplItemFn) { + self.state.push_back(CallingStyleVisitorState::Block); + syn::visit::visit_impl_item_fn(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_expr_loop(&mut self, i: &'_ syn::ExprLoop) { + self.state.push_back(CallingStyleVisitorState::Loop); + syn::visit::visit_expr_loop(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_expr_for_loop(&mut self, i: &'_ syn::ExprForLoop) { + self.state.push_back(CallingStyleVisitorState::Loop); + syn::visit::visit_expr_for_loop(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_expr_if(&mut self, i: &'_ syn::ExprIf) { + self.state.push_back(CallingStyleVisitorState::If); + syn::visit::visit_expr_if(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_expr_closure(&mut self, i: &'_ syn::ExprClosure) { + self.state.push_back(CallingStyleVisitorState::Closure); + syn::visit::visit_expr_closure(self, i); + if !self.halt { + self.state.pop_back(); + } + } + + fn visit_expr_call(&mut self, i: &'_ syn::ExprCall) { + syn::visit::visit_expr_call(self, i); + if let Expr::Path(p) = i.func.as_ref() { + if let Some(last) = p.path.segments.last() { + if is_match( + &self.reference.identifier, + &last.ident, + &self.reference.references, + ) { + self.halt = true; + } + } + } + } + + // to validate this, we first check if the name is the same and then compare it + // against any of the references we are holding + fn visit_expr_method_call(&mut self, i: &'_ syn::ExprMethodCall) { + if is_match( + &self.reference.identifier, + &i.method, + &self.reference.references, + ) { + self.halt = true; + } + + syn::visit::visit_expr_method_call(self, i); + } +} + +/// Check if some ident referenced by `check` is calling the `target` by +/// looking it up in the list of known `ranges`. +fn is_match(target: &Identifier, check: &syn::Ident, ranges: &[Range]) -> bool { + if target.equals_ident(check, false) { + let span = check.span(); + // syn is 1-indexed, range is not + for reference in ranges { + if reference.start.line != span.start().line as u32 - 1 { + continue; + } + + if reference.start.character != span.start().column as u32 { + continue; + } + + if reference.end.line != span.end().line as u32 - 1 { + continue; + } + + if reference.end.character != span.end().column as u32 { + continue; + } + + // match, just exit the visitor + return true; + } + } + + false +}