diff --git a/Cargo.toml b/Cargo.toml index f2f385d..dbdcfe5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ rayon = "1.7.0" regex = "1.8.2" termcolor = "1.2.0" tree-sitter = "0.20.10" +tree-sitter-javascript = "0.20.0" tree-sitter-rust = "0.20.3" tree-sitter-typescript = "0.20.2" diff --git a/src/language.rs b/src/language.rs index 1fedfd2..9d7f114 100644 --- a/src/language.rs +++ b/src/language.rs @@ -1,10 +1,13 @@ +use std::{collections::HashMap, ffi::OsStr, path::Path}; + use clap::ValueEnum; use tree_sitter::Language; -#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum, Hash)] pub enum SupportedLanguageName { Rust, Typescript, + Javascript, } impl SupportedLanguageName { @@ -12,13 +15,16 @@ impl SupportedLanguageName { match self { Self::Rust => Box::new(get_rust_language()), Self::Typescript => Box::new(get_typescript_language()), + Self::Javascript => Box::new(get_javascript_language()), } } } pub trait SupportedLanguage { fn language(&self) -> Language; + fn name(&self) -> SupportedLanguageName; fn name_for_ignore_select(&self) -> &'static str; + fn extensions(&self) -> Vec<&'static str>; } pub struct SupportedLanguageRust; @@ -28,9 +34,17 @@ impl SupportedLanguage for SupportedLanguageRust { tree_sitter_rust::language() } + fn name(&self) -> SupportedLanguageName { + SupportedLanguageName::Rust + } + fn name_for_ignore_select(&self) -> &'static str { "rust" } + + fn extensions(&self) -> Vec<&'static str> { + vec!["rs"] + } } pub fn get_rust_language() -> SupportedLanguageRust { @@ -44,11 +58,68 @@ impl SupportedLanguage for SupportedLanguageTypescript { tree_sitter_typescript::language_tsx() } + fn name(&self) -> SupportedLanguageName { + SupportedLanguageName::Typescript + } + fn name_for_ignore_select(&self) -> &'static str { "ts" } + + fn extensions(&self) -> Vec<&'static str> { + vec!["ts", "tsx"] + } } pub fn get_typescript_language() -> SupportedLanguageTypescript { SupportedLanguageTypescript } + +pub struct SupportedLanguageJavascript; + +impl SupportedLanguage for SupportedLanguageJavascript { + fn language(&self) -> Language { + tree_sitter_javascript::language() + } + + fn name(&self) -> SupportedLanguageName { + SupportedLanguageName::Javascript + } + + fn name_for_ignore_select(&self) -> &'static str { + "js" + } + + fn extensions(&self) -> Vec<&'static str> { + vec!["js", "jsx", "vue", "cjs", "mjs"] + } +} + +pub fn get_javascript_language() -> SupportedLanguageJavascript { + SupportedLanguageJavascript +} + +pub fn get_all_supported_languages() -> HashMap> { + [ + ( + SupportedLanguageName::Rust, + Box::new(get_rust_language()) as Box, + ), + ( + SupportedLanguageName::Typescript, + Box::new(get_typescript_language()) as Box, + ), + ( + SupportedLanguageName::Javascript, + Box::new(get_javascript_language()) as Box, + ), + ] + .into() +} + +pub fn maybe_supported_language_from_path(path: &Path) -> Option> { + let extension = path.extension().and_then(OsStr::to_str)?; + get_all_supported_languages() + .into_values() + .find(|language| language.extensions().contains(&extension)) +} diff --git a/src/lib.rs b/src/lib.rs index c6e977e..192d77f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,19 @@ use std::{ cell::RefCell, + collections::HashMap, fs, path::{Path, PathBuf}, + process, rc::Rc, - sync::{mpsc, mpsc::Receiver, Arc}, + sync::{ + atomic::{AtomicU32, Ordering}, + mpsc, + mpsc::Receiver, + Arc, Mutex, + }, thread, thread::JoinHandle, + time::Duration, }; use clap::Parser; @@ -25,9 +33,12 @@ mod macros; mod plugin; mod treesitter; -use language::{SupportedLanguage, SupportedLanguageName}; +use language::{ + get_all_supported_languages, maybe_supported_language_from_path, SupportedLanguage, + SupportedLanguageName, +}; use plugin::get_loaded_filter; -use treesitter::{get_matches, get_query}; +use treesitter::{get_matches, maybe_get_query}; #[derive(Parser)] pub struct Args { @@ -37,7 +48,7 @@ pub struct Args { #[arg(short, long = "capture")] pub capture_name: Option, #[arg(short, long, value_enum)] - pub language: SupportedLanguageName, + pub language: Option, #[arg(short, long)] pub filter: Option, #[arg(short = 'a', long)] @@ -83,25 +94,100 @@ fn get_output_mode(args: &Args) -> OutputMode { } } +struct MaybeInitializedCaptureIndex(AtomicU32); + +impl MaybeInitializedCaptureIndex { + const UNINITIALIZED: u32 = u32::MAX; + const FAILED: u32 = u32::MAX - 1; + + fn mark_failed(&self) -> bool { + loop { + let existing_value = self.0.load(Ordering::Relaxed); + if existing_value == Self::FAILED { + return false; + } + let did_store = self.0.compare_exchange( + existing_value, + Self::FAILED, + Ordering::Relaxed, + Ordering::Relaxed, + ); + if did_store.is_ok() { + return true; + } + } + } + + pub fn get(&self) -> Result, ()> { + let loaded = self.0.load(Ordering::Relaxed); + match loaded { + loaded if loaded == Self::UNINITIALIZED => Ok(None), + loaded if loaded == Self::FAILED => Err(()), + loaded => Ok(Some(loaded)), + } + } + + pub fn get_or_initialize(&self, query: &Query, capture_name: Option<&str>) -> Result { + if let Some(already_initialized) = self.get()? { + return Ok(already_initialized); + } + let capture_index = match capture_name { + None => 0, + Some(capture_name) => { + let capture_index = query.capture_index_for_name(capture_name); + if capture_index.is_none() { + let did_mark_failed = self.mark_failed(); + if did_mark_failed { + fail(&format!("invalid capture name '{}'", capture_name)); + } else { + // whichever other thread "won the race" will have called this fail() + // so we'll be getting killed shortly? + thread::sleep(Duration::from_millis(100_000)); + } + } + capture_index.unwrap() + } + }; + self.set(capture_index); + Ok(capture_index) + } + + fn set(&self, capture_index: u32) { + self.0.store(capture_index, Ordering::Relaxed); + } +} + +impl Default for MaybeInitializedCaptureIndex { + fn default() -> Self { + Self(AtomicU32::new(Self::UNINITIALIZED)) + } +} + pub fn run(args: Args) { let query_source = match args.query_args.path_to_query_file.as_ref() { Some(path_to_query_file) => fs::read_to_string(path_to_query_file).unwrap(), None => args.query_args.query_source.clone().unwrap(), }; - let supported_language = args.language.get_language(); - let language = supported_language.language(); - let query = Arc::new(get_query(&query_source, language)); - let capture_index = args.capture_name.as_ref().map_or(0, |capture_name| { - query - .capture_index_for_name(capture_name) - .unwrap_or_else(|| panic!("Unknown capture name: `{}`", capture_name)) - }); + let specified_supported_language = args.language.map(|language| language.get_language()); + let query_or_failure_by_language: Mutex>>> = + Default::default(); + let capture_index = MaybeInitializedCaptureIndex::default(); let output_mode = get_output_mode(&args); let buffer_writer = BufferWriter::stdout(ColorChoice::Never); - get_project_file_walker(&*supported_language, &args.use_paths()) + get_project_file_walker(specified_supported_language.as_deref(), &args.use_paths()) .into_parallel_iterator() .for_each(|project_file_dir_entry| { + let language = maybe_supported_language_from_path(project_file_dir_entry.path()) + .expect("Walker should've been pre-filtered to just supported file types"); + let query = return_if_none!(get_and_cache_query_for_language( + &query_source, + &query_or_failure_by_language, + &*language, + )); + let capture_index = return_if_none!(capture_index + .get_or_initialize(&query, args.capture_name.as_deref()) + .ok()); let printer = get_printer(&buffer_writer, output_mode); let mut printer = printer.borrow_mut(); let path = @@ -110,7 +196,7 @@ pub fn run(args: Args) { let matcher = TreeSitterMatcher::new( &query, capture_index, - language, + language.language(), args.filter.clone(), args.filter_arg.clone(), ); @@ -122,6 +208,38 @@ pub fn run(args: Args) { .unwrap(); buffer_writer.print(printer.get_mut()).unwrap(); }); + + error_if_no_successful_query_parsing(&query_or_failure_by_language); +} + +fn error_if_no_successful_query_parsing( + query_or_failure_by_language: &Mutex>>>, +) { + let query_or_failure_by_language = query_or_failure_by_language.lock().unwrap(); + if !query_or_failure_by_language + .values() + .any(|query| query.is_some()) + { + fail("invalid query"); + } +} + +fn fail(message: &str) -> ! { + eprintln!("error: {message}"); + process::exit(1); +} + +fn get_and_cache_query_for_language( + query_source: &str, + query_or_failure_by_language: &Mutex>>>, + language: &dyn SupportedLanguage, +) -> Option> { + query_or_failure_by_language + .lock() + .unwrap() + .entry(language.name()) + .or_insert_with(|| maybe_get_query(query_source, language.language()).map(Arc::new)) + .clone() } type Printer = Standard; @@ -244,16 +362,22 @@ impl Iterator for WalkParallelIterator { } } -fn get_project_file_walker(language: &dyn SupportedLanguage, paths: &[PathBuf]) -> WalkParallel { +fn get_project_file_walker( + language: Option<&dyn SupportedLanguage>, + paths: &[PathBuf], +) -> WalkParallel { assert!(!paths.is_empty()); let mut builder = WalkBuilder::new(&paths[0]); - builder.types( - TypesBuilder::new() - .add_defaults() - .select(language.name_for_ignore_select()) - .build() - .unwrap(), - ); + let mut types_builder = TypesBuilder::new(); + types_builder.add_defaults(); + if let Some(language) = language { + types_builder.select(language.name_for_ignore_select()); + } else { + for language in get_all_supported_languages().values() { + types_builder.select(language.name_for_ignore_select()); + } + } + builder.types(types_builder.build().unwrap()); for path in &paths[1..] { builder.add(path); } diff --git a/src/macros.rs b/src/macros.rs index 2f30638..34f34a0 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -5,3 +5,15 @@ macro_rules! regex { RE.get_or_init(|| regex::Regex::new($re).unwrap()) }}; } + +#[macro_export] +macro_rules! return_if_none { + ($expr:expr $(,)?) => { + match $expr { + None => { + return; + } + Some(expr) => expr, + } + }; +} diff --git a/src/treesitter/mod.rs b/src/treesitter/mod.rs index f2848ba..1a7ed60 100644 --- a/src/treesitter/mod.rs +++ b/src/treesitter/mod.rs @@ -11,8 +11,8 @@ pub fn get_parser(language: Language) -> Parser { parser } -pub fn get_query(source: &str, language: Language) -> Query { - Query::new(language, source).unwrap() +pub fn maybe_get_query(source: &str, language: Language) -> Option { + Query::new(language, source).ok() } pub fn get_matches( diff --git a/tests/fixtures/mixed_project/javascript_src/index.js b/tests/fixtures/mixed_project/javascript_src/index.js new file mode 100644 index 0000000..9d58dd8 --- /dev/null +++ b/tests/fixtures/mixed_project/javascript_src/index.js @@ -0,0 +1 @@ +const js_foo = () => {} diff --git a/tests/fixtures/mixed_project/rust_src/lib.rs b/tests/fixtures/mixed_project/rust_src/lib.rs new file mode 100644 index 0000000..8f3b7ef --- /dev/null +++ b/tests/fixtures/mixed_project/rust_src/lib.rs @@ -0,0 +1 @@ +fn foo() {} diff --git a/tests/fixtures/mixed_project/typescript_src/index.tsx b/tests/fixtures/mixed_project/typescript_src/index.tsx new file mode 100644 index 0000000..ca3076a --- /dev/null +++ b/tests/fixtures/mixed_project/typescript_src/index.tsx @@ -0,0 +1 @@ +const foo = () => {} diff --git a/tests/fixtures/rust_project/function-itemz.scm b/tests/fixtures/rust_project/function-itemz.scm new file mode 100644 index 0000000..3f28c96 --- /dev/null +++ b/tests/fixtures/rust_project/function-itemz.scm @@ -0,0 +1 @@ +(function_itemz) @function_item diff --git a/tests/output.rs b/tests/output.rs index a4d7a41..bedf27b 100644 --- a/tests/output.rs +++ b/tests/output.rs @@ -88,6 +88,25 @@ fn do_sorted_lines_match(actual_output: &str, expected_output: &str) -> bool { actual_lines == expected_lines } +fn assert_failure_output(fixture_dir_name: &str, command_and_output: &str) { + let CommandAndOutput { + mut command_line_args, + output, + } = parse_command_and_output(command_and_output); + let command_name = command_line_args.remove(0); + Command::cargo_bin(command_name) + .unwrap() + .args(command_line_args) + .current_dir(get_fixture_dir_path_from_name(fixture_dir_name)) + .assert() + .failure() + // .stderr(predicate::function(|stderr: &str| { + // println!("stderr: {stderr:#?}, output: {output:#?}"); + // stderr == output + // })); + .stderr(predicate::eq(output)); +} + #[test] fn test_query_inline() { assert_sorted_output( @@ -229,3 +248,94 @@ fn test_specify_multiple_files() { "#, ); } + +#[test] +fn test_invalid_query_inline() { + assert_failure_output( + "rust_project", + r#" + $ tree-sitter-grep --query-source '(function_itemz) @function_item' --language rust + error: invalid query + "#, + ); +} + +#[test] +fn test_invalid_query_file() { + assert_failure_output( + "rust_project", + r#" + $ tree-sitter-grep --query-file ./function-itemz.scm --language rust + error: invalid query + "#, + ); +} + +#[test] +fn test_no_query_specified() { + assert_failure_output( + "rust_project", + r#" + $ tree-sitter-grep --language rust + error: the following required arguments were not provided: + <--query-file |--query-source > + + Usage: tree-sitter-grep --language <--query-file |--query-source > [PATHS]... + + For more information, try '--help'. + "#, + ); +} + +#[test] +fn test_invalid_capture_name() { + assert_failure_output( + "rust_project", + r#" + $ tree-sitter-grep --query-source '(function_item) @function_item' --language rust --capture function_itemz + error: invalid capture name 'function_itemz' + "#, + ); +} + +#[test] +fn test_auto_language_single_known_language_encountered() { + assert_sorted_output( + "rust_project", + r#" + $ tree-sitter-grep --query-source '(function_item) @function_item' + src/helpers.rs:1:pub fn helper() {} + src/lib.rs:3:pub fn add(left: usize, right: usize) -> usize { + src/lib.rs:4: left + right + src/lib.rs:5:} + src/lib.rs:12: fn it_works() { + src/lib.rs:13: let result = add(2, 2); + src/lib.rs:14: assert_eq!(result, 4); + src/lib.rs:15: } + src/stop.rs:1:fn stop_it() {} + "#, + ); +} + +#[test] +fn test_auto_language_multiple_parseable_languages() { + assert_sorted_output( + "mixed_project", + r#" + $ tree-sitter-grep --query-source '(arrow_function) @arrow_function' + javascript_src/index.js:1:const js_foo = () => {} + typescript_src/index.tsx:1:const foo = () => {} + "#, + ); +} + +#[test] +fn test_auto_language_single_parseable_languages() { + assert_sorted_output( + "mixed_project", + r#" + $ tree-sitter-grep --query-source '(function_item) @function_item' + rust_src/lib.rs:1:fn foo() {} + "#, + ); +}