Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: call built-ins as predicate #557

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 95 additions & 8 deletions crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use grit_pattern_matcher::{
constant::Constant,
context::ExecContext,
pattern::{
get_absolute_file_name, CallBuiltIn, CallbackPattern, JoinFn, LazyBuiltIn, Pattern,
ResolvedPattern, ResolvedSnippet, State,
get_absolute_file_name, get_file_name, CallBuiltIn, CallbackPattern, JoinFn, LazyBuiltIn,
Pattern, ResolvedPattern, ResolvedSnippet, State,
},
};
use grit_util::{AnalysisLogs, CodeRange, Language};
use grit_util::{AnalysisLogBuilder, AnalysisLogs, CodeRange, Language};
use itertools::Itertools;
use rand::prelude::SliceRandom;
use rand::Rng;
Expand Down Expand Up @@ -47,6 +47,7 @@ pub struct BuiltInFunction {
pub name: &'static str,
pub params: Vec<&'static str>,
pub(crate) func: Box<CallableFn>,
pub(crate) position: BuiltInFunctionPosition,
}

impl BuiltInFunction {
Expand All @@ -61,7 +62,22 @@ impl BuiltInFunction {
}

pub fn new(name: &'static str, params: Vec<&'static str>, func: Box<CallableFn>) -> Self {
Self { name, params, func }
Self {
name,
params,
func,
position: BuiltInFunctionPosition::Pattern,
}
}

pub fn as_predicate(mut self) -> Self {
self.position = BuiltInFunctionPosition::Predicate;
self
}

pub fn as_predicate_or_pattern(mut self) -> Self {
self.position = BuiltInFunctionPosition::Both;
self
}
}

Expand All @@ -74,6 +90,23 @@ impl std::fmt::Debug for BuiltInFunction {
}
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BuiltInFunctionPosition {
Pattern,
Predicate,
Both,
}

impl BuiltInFunctionPosition {
pub fn is_pattern(&self) -> bool {
matches!(self, Self::Pattern | Self::Both)
}

pub fn is_predicate(&self) -> bool {
matches!(self, Self::Predicate | Self::Both)
}
}

pub struct BuiltIns {
built_ins: Vec<BuiltInFunction>,
callbacks: Vec<Box<CallbackFn>>,
Expand Down Expand Up @@ -108,10 +141,7 @@ impl BuiltIns {
let params = &built_ins.built_ins[index].params;
let mut pattern_params = Vec::with_capacity(args.len());
for param in params.iter() {
match args.remove(&(lang.metavariable_prefix().to_owned() + param)) {
Some(p) => pattern_params.push(Some(p)),
None => pattern_params.push(None),
}
pattern_params.push(args.remove(&(lang.metavariable_prefix().to_owned() + param)));
}
Ok(CallBuiltIn::new(index, name, pattern_params))
}
Expand Down Expand Up @@ -179,6 +209,8 @@ impl BuiltIns {
BuiltInFunction::new("shuffle", vec!["list"], Box::new(shuffle_fn)),
BuiltInFunction::new("random", vec!["floor", "ceiling"], Box::new(random_fn)),
BuiltInFunction::new("split", vec!["string", "separator"], Box::new(split_fn)),
BuiltInFunction::new("log", vec!["message", "variable"], Box::new(log_fn))
.as_predicate_or_pattern(),
]
.into()
}
Expand Down Expand Up @@ -546,3 +578,58 @@ fn ai_fn_placeholder<'a>(
) -> Result<MarzanoResolvedPattern<'a>> {
bail!("AI features are not supported in your GritQL distribution. Please upgrade to the Enterprise version to use AI features.")
}

fn log_fn<'a>(
args: &'a [Option<Pattern<MarzanoQueryContext>>],
context: &'a MarzanoContext<'a>,
state: &mut State<'a, MarzanoQueryContext>,
logs: &mut AnalysisLogs,
) -> Result<MarzanoResolvedPattern<'a>> {
let mut message = args[0]
.as_ref()
.map(|message| {
MarzanoResolvedPattern::from_pattern(message, state, context, logs)
.and_then(|resolved| resolved.text(&state.files, context.language))
.map(|user_message| format!("{user_message}\n"))
})
.transpose()?
.unwrap_or_default();

let mut log_builder = AnalysisLogBuilder::default();
let file = get_file_name(state, context.language())?;
#[allow(clippy::unnecessary_cast)]
log_builder.level(441 as u16).file(file);

if let Some(Some(Pattern::Variable(variable))) = args.get(1) {
let var = state.trace_var_mut(variable);
let var_content = &state.bindings[var.try_scope().unwrap() as usize]
.last()
.unwrap()[var.try_index().unwrap() as usize];
let value = var_content.value.as_ref();

let src = value
.map(|v| {
v.text(&state.files, context.language())
.map(|s| s.to_string())
})
.unwrap_or(Ok("Variable has no source".to_string()))?;
log_builder.source(src);

let node = value.and_then(|v| v.get_last_binding());
// todo add support for other types of bindings
if let Some(node) = node {
if let Some(position) = node.position(context.language()) {
log_builder.range(position);
}
if let Some(syntax_tree) = node.get_sexp() {
log_builder.syntax_tree(syntax_tree);
}
} else {
message.push_str("attempted to log a non-node binding, such bindings don't have syntax tree or range\n")
}
}
log_builder.message(message);
logs.push(log_builder.build()?);

Ok(MarzanoResolvedPattern::Constant(Constant::Boolean(true)))
}
1 change: 0 additions & 1 deletion crates/core/src/marzano_resolved_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,6 @@ impl<'a> ResolvedPattern<'a, MarzanoQueryContext> for MarzanoResolvedPattern<'a>
| Pattern::Underscore
| Pattern::AstLeafNode(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down
14 changes: 10 additions & 4 deletions crates/core/src/optimizer/hoist_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ pub fn extract_filename_pattern<Q: QueryContext>(
}
Pattern::Some(some) => extract_filename_pattern(&some.pattern),

Pattern::Log(_) => Ok(Some(Pattern::Top)),
Pattern::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

Pattern::Add(add) => {
let Some(lhs) = extract_filename_pattern(&add.lhs)? else {
Expand Down Expand Up @@ -253,7 +255,9 @@ impl<Q: QueryContext> FilenamePatternExtractor<Q> for Predicate<Q> {
}

Predicate::Rewrite(rw) => extract_filename_pattern(&rw.left),
Predicate::Log(_) => Ok(Some(Pattern::Top)),
Predicate::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

// If we hit a leaf predicate that is *not* a match, stop traversing - it is always true
Predicate::True => Ok(Some(Pattern::Top)),
Expand Down Expand Up @@ -282,7 +286,10 @@ impl<Q: QueryContext> FilenamePatternExtractor<Q> for Predicate<Q> {
}

// These are more complicated, implement carefully
Predicate::Call(_) | Predicate::Not(_) | Predicate::Equal(_) => Ok(None),
Predicate::Call(_)
| Predicate::CallBuiltIn(_)
| Predicate::Not(_)
| Predicate::Equal(_) => Ok(None),
}
}
}
Expand Down Expand Up @@ -335,7 +342,6 @@ pub(crate) fn is_safe_to_hoist<Q: QueryContext>(pattern: &Pattern<Q>) -> Result<
| Pattern::Dynamic(_)
| Pattern::Variable(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Within(_)
| Pattern::After(_)
Expand Down
14 changes: 10 additions & 4 deletions crates/core/src/optimizer/hoist_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ fn extract_pattern_text<Q: QueryContext>(pattern: &Pattern<Q>) -> Result<Option<
| Pattern::Dynamic(_)
| Pattern::Variable(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down Expand Up @@ -171,7 +170,9 @@ pub fn extract_body_pattern<Q: QueryContext>(
}
Pattern::Some(some) => extract_body_pattern(&some.pattern, matching_body),

Pattern::Log(_) => Ok(Some(Pattern::Top)),
Pattern::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

Pattern::Add(add) => {
let Some(lhs) = extract_body_pattern(&add.lhs, matching_body)? else {
Expand Down Expand Up @@ -368,7 +369,9 @@ impl<Q: QueryContext> BodyPatternExtractor<Q> for Predicate<Q> {
}

Predicate::Rewrite(rw) => extract_body_pattern(&rw.left, false),
Predicate::Log(_) => Ok(Some(Pattern::Top)),
Predicate::CallBuiltIn(call_built_in) if call_built_in.name == "log" => {
Ok(Some(Pattern::Top))
}

// If we hit a leaf predicate that is *not* a match, stop traversing - it is always true
Predicate::True => Ok(Some(Pattern::Top)),
Expand Down Expand Up @@ -397,7 +400,10 @@ impl<Q: QueryContext> BodyPatternExtractor<Q> for Predicate<Q> {
}

// These are more complicated, implement carefully
Predicate::Call(_) | Predicate::Not(_) | Predicate::Equal(_) => Ok(None),
Predicate::Call(_)
| Predicate::CallBuiltIn(_)
| Predicate::Not(_)
| Predicate::Equal(_) => Ok(None),
}
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/core/src/pattern_compiler/accumulate_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ impl NodeCompiler for AccumulateCompiler {
| Pattern::BooleanConstant(_)
| Pattern::CodeSnippet(_)
| Pattern::Rewrite(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down
4 changes: 0 additions & 4 deletions crates/core/src/pattern_compiler/auto_wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ pub fn is_sequential(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Contains(_)
| Pattern::Includes(_)
Expand Down Expand Up @@ -185,7 +184,6 @@ pub(crate) fn should_autowrap(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down Expand Up @@ -279,7 +277,6 @@ fn extract_limit_pattern(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down Expand Up @@ -354,7 +351,6 @@ pub fn should_wrap_in_file(
| Pattern::Dynamic(_)
| Pattern::CodeSnippet(_)
| Pattern::Variable(_)
| Pattern::Log(_)
| Pattern::Range(_)
| Pattern::Includes(_)
| Pattern::Within(_)
Expand Down
Loading
Loading