Skip to content

Commit

Permalink
feat: call built-ins as predicate (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
arendjr authored Oct 24, 2024
1 parent 4642258 commit a1b312b
Show file tree
Hide file tree
Showing 20 changed files with 346 additions and 234 deletions.
103 changes: 95 additions & 8 deletions crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use grit_pattern_matcher::{
constant::Constant,
context::ExecContext,
pattern::{
get_absolute_file_name, CallBuiltIn, CallbackPattern, JoinFn, LazyBuiltIn, Pattern,
ResolvedPattern, ResolvedSnippet, State,
get_absolute_file_name, get_file_name, CallBuiltIn, CallbackPattern, JoinFn, LazyBuiltIn,
Pattern, ResolvedPattern, ResolvedSnippet, State,
},
};
use grit_util::{AnalysisLogs, CodeRange, Language};
use grit_util::{AnalysisLogBuilder, AnalysisLogs, CodeRange, Language};
use itertools::Itertools;
use rand::prelude::SliceRandom;
use rand::Rng;
Expand Down Expand Up @@ -47,6 +47,7 @@ pub struct BuiltInFunction {
pub name: &'static str,
pub params: Vec<&'static str>,
pub(crate) func: Box<CallableFn>,
pub(crate) position: BuiltInFunctionPosition,
}

impl BuiltInFunction {
Expand All @@ -61,7 +62,22 @@ impl BuiltInFunction {
}

pub fn new(name: &'static str, params: Vec<&'static str>, func: Box<CallableFn>) -> Self {
Self { name, params, func }
Self {
name,
params,
func,
position: BuiltInFunctionPosition::Pattern,
}
}

pub fn as_predicate(mut self) -> Self {
self.position = BuiltInFunctionPosition::Predicate;
self
}

pub fn as_predicate_or_pattern(mut self) -> Self {
self.position = BuiltInFunctionPosition::Both;
self
}
}

Expand All @@ -74,6 +90,23 @@ impl std::fmt::Debug for BuiltInFunction {
}
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BuiltInFunctionPosition {
Pattern,
Predicate,
Both,
}

impl BuiltInFunctionPosition {
pub fn is_pattern(&self) -> bool {
matches!(self, Self::Pattern | Self::Both)
}

pub fn is_predicate(&self) -> bool {
matches!(self, Self::Predicate | Self::Both)
}
}

pub struct BuiltIns {
built_ins: Vec<BuiltInFunction>,
callbacks: Vec<Box<CallbackFn>>,
Expand Down Expand Up @@ -108,10 +141,7 @@ impl BuiltIns {
let params = &built_ins.built_ins[index].params;
let mut pattern_params = Vec::with_capacity(args.len());
for param in params.iter() {
match args.remove(&(lang.metavariable_prefix().to_owned() + param)) {
Some(p) => pattern_params.push(Some(p)),
None => pattern_params.push(None),
}
pattern_params.push(args.remove(&(lang.metavariable_prefix().to_owned() + param)));
}
Ok(CallBuiltIn::new(index, name, pattern_params))
}
Expand Down Expand Up @@ -179,6 +209,8 @@ impl BuiltIns {
BuiltInFunction::new("shuffle", vec!["list"], Box::new(shuffle_fn)),
BuiltInFunction::new("random", vec!["floor", "ceiling"], Box::new(random_fn)),
BuiltInFunction::new("split", vec!["string", "separator"], Box::new(split_fn)),
BuiltInFunction::new("log", vec!["message", "variable"], Box::new(log_fn))
.as_predicate_or_pattern(),
]
.into()
}
Expand Down Expand Up @@ -546,3 +578,58 @@ fn ai_fn_placeholder<'a>(
) -> Result<MarzanoResolvedPattern<'a>> {
bail!("AI features are not supported in your GritQL distribution. Please upgrade to the Enterprise version to use AI features.")
}

fn log_fn<'a>(
args: &'a [Option<Pattern<MarzanoQueryContext>>],
context: &'a MarzanoContext<'a>,
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> Result<MarzanoResolvedPattern<'a>> {
let mut message = args[0]
.as_ref()
.map(|message| {
MarzanoResolvedPattern::from_pattern(message, state, context, logs)
.and_then(|resolved| resolved.text(&state.files, context.language))
.map(|user_message| format!("{user_message}\n"))
})
.transpose()?
.unwrap_or_default();

let mut log_builder = AnalysisLogBuilder::default();
let file = get_file_name(state, context.language())?;
#[allow(clippy::unnecessary_cast)]
log_builder.level(441 as u16).file(file);

if let Some(Some(Pattern::Variable(variable))) = args.get(1) {
let var = state.trace_var_mut(variable);
let var_content = &state.bindings[var.try_scope().unwrap() as usize]
.last()
.unwrap()[var.try_index().unwrap() as usize];
let value = var_content.value.as_ref();

let src = value
.map(|v| {
v.text(&state.files, context.language())
.map(|s| s.to_string())
})
.unwrap_or(Ok("Variable has no source".to_string()))?;
log_builder.source(src);

let node = value.and_then(|v| v.get_last_binding());
// todo add support for other types of bindings
if let Some(node) = node {
if let Some(position) = node.position(context.language()) {
log_builder.range(position);
}
if let Some(syntax_tree) = node.get_sexp() {
log_builder.syntax_tree(syntax_tree);
}
} else {
message.push_str("attempted to log a non-node binding, such bindings don't have syntax tree or range\n")
}
}
log_builder.message(message);
logs.push(log_builder.build()?);

Ok(MarzanoResolvedPattern::Constant(Constant::Boolean(true)))
}
1 change: 0 additions & 1 deletion crates/core/src/marzano_resolved_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,6 @@ impl<'a> ResolvedPattern<'a, MarzanoQueryContext> for MarzanoResolvedPattern<'a>
| Pattern::Underscore
| Pattern::AstLeafNode(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down
14 changes: 10 additions & 4 deletions crates/core/src/optimizer/hoist_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ pub fn extract_filename_pattern<Q: QueryContext>(
}
Pattern::Some(some) => extract_filename_pattern(&some.pattern),

Pattern::Log(_) => Ok(Some(Pattern::Top)),
Pattern::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

Pattern::Add(add) => {
let Some(lhs) = extract_filename_pattern(&add.lhs)? else {
Expand Down Expand Up @@ -253,7 +255,9 @@ impl<Q: QueryContext> FilenamePatternExtractor<Q> for Predicate<Q> {
}

Predicate::Rewrite(rw) => extract_filename_pattern(&rw.left),
Predicate::Log(_) => Ok(Some(Pattern::Top)),
Predicate::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

// If we hit a leaf predicate that is *not* a match, stop traversing - it is always true
Predicate::True => Ok(Some(Pattern::Top)),
Expand Down Expand Up @@ -282,7 +286,10 @@ impl<Q: QueryContext> FilenamePatternExtractor<Q> for Predicate<Q> {
}

// These are more complicated, implement carefully
Predicate::Call(_) | Predicate::Not(_) | Predicate::Equal(_) => Ok(None),
Predicate::Call(_)
| Predicate::CallBuiltIn(_)
| Predicate::Not(_)
| Predicate::Equal(_) => Ok(None),
}
}
}
Expand Down Expand Up @@ -335,7 +342,6 @@ pub(crate) fn is_safe_to_hoist<Q: QueryContext>(pattern: &Pattern<Q>) -> Result<
| Pattern::Dynamic(_)
| Pattern::Variable(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Within(_)
| Pattern::After(_)
Expand Down
14 changes: 10 additions & 4 deletions crates/core/src/optimizer/hoist_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ fn extract_pattern_text<Q: QueryContext>(pattern: &Pattern<Q>) -> Result<Option<
| Pattern::Dynamic(_)
| Pattern::Variable(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down Expand Up @@ -171,7 +170,9 @@ pub fn extract_body_pattern<Q: QueryContext>(
}
Pattern::Some(some) => extract_body_pattern(&some.pattern, matching_body),

Pattern::Log(_) => Ok(Some(Pattern::Top)),
Pattern::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

Pattern::Add(add) => {
let Some(lhs) = extract_body_pattern(&add.lhs, matching_body)? else {
Expand Down Expand Up @@ -368,7 +369,9 @@ impl<Q: QueryContext> BodyPatternExtractor<Q> for Predicate<Q> {
}

Predicate::Rewrite(rw) => extract_body_pattern(&rw.left, false),
Predicate::Log(_) => Ok(Some(Pattern::Top)),
Predicate::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

// If we hit a leaf predicate that is *not* a match, stop traversing - it is always true
Predicate::True => Ok(Some(Pattern::Top)),
Expand Down Expand Up @@ -397,7 +400,10 @@ impl<Q: QueryContext> BodyPatternExtractor<Q> for Predicate<Q> {
}

// These are more complicated, implement carefully
Predicate::Call(_) | Predicate::Not(_) | Predicate::Equal(_) => Ok(None),
Predicate::Call(_)
| Predicate::CallBuiltIn(_)
| Predicate::Not(_)
| Predicate::Equal(_) => Ok(None),
}
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/core/src/pattern_compiler/accumulate_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ impl NodeCompiler for AccumulateCompiler {
| Pattern::BooleanConstant(_)
| Pattern::CodeSnippet(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down
4 changes: 0 additions & 4 deletions crates/core/src/pattern_compiler/auto_wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ pub fn is_sequential(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down Expand Up @@ -185,7 +184,6 @@ pub(crate) fn should_autowrap(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down Expand Up @@ -279,7 +277,6 @@ fn extract_limit_pattern(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down Expand Up @@ -354,7 +351,6 @@ pub fn should_wrap_in_file(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down
Loading

0 comments on commit a1b312b

Please sign in to comment.