Skip to content

Commit

Permalink
feat: Connect Python builder to Rust (#561)
Browse files Browse the repository at this point in the history
This creates DFG nodes and checks their types.
    
Some extra features:
    
- Uses the pyarrow types of the expression to limit some overloads.
  In the case of `__getattr__` this prevents mistakes that easily lead to
  infinite recursion. In the case of `expr[expr]` this lets us use the
  correct methods.
- Allows using python literals (`str`, `int` and `float`) as arguments
  to expressions.
- Renames `ffi` module to `_ffi` to indicate it is private.
- Add some tests for error cases
  • Loading branch information
bjchambers authored Jul 27, 2023
1 parent 82adac4 commit 4483e13
Show file tree
Hide file tree
Showing 27 changed files with 3,619 additions and 624 deletions.
12 changes: 6 additions & 6 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod window_args;

#[cfg(test)]
mod tests;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::{anyhow, Context};
use arrow::datatypes::{DataType, FieldRef};
Expand Down Expand Up @@ -135,7 +135,7 @@ pub(super) fn add_to_dfg(

if CastEvaluator::is_supported_fenl(from_type, to_type) {
if let FenlType::Concrete(to_type) = to_type.inner() {
return Ok(Rc::new(AstDfg::new(
return Ok(Arc::new(AstDfg::new(
dfg.add_expression(
Expression::Inst(InstKind::Cast(to_type.clone())),
smallvec![input.value()],
Expand Down Expand Up @@ -237,7 +237,7 @@ pub(super) fn add_to_dfg(
)?;
let is_new = base.is_new();
let value_type = field_type.clone().into();
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -275,7 +275,7 @@ pub(super) fn add_to_dfg(
let agg_input_op = dfg.operation(agg_input.value());
let tick_input = smallvec![agg_input_op];
let tick_node = dfg.add_operation(Operation::Tick(behavior), tick_input)?;
let tick_node = Rc::new(AstDfg::new(
let tick_node = Arc::new(AstDfg::new(
tick_node,
tick_node,
FenlType::Concrete(DataType::Boolean),
Expand Down Expand Up @@ -508,7 +508,7 @@ pub(super) fn add_to_dfg(
// Add cast operations as necessary
let args: Vec<_> = izip!(arguments, instantiated_types)
.map(|(arg, expected_type)| -> anyhow::Result<_> {
let ast_dfg = Rc::new(AstDfg::new(
let ast_dfg = Arc::new(AstDfg::new(
cast_if_needed(dfg, arg.value(), arg.value_type(), &expected_type)?,
arg.is_new(),
expected_type,
Expand Down Expand Up @@ -766,7 +766,7 @@ fn add_literal(
location: Location,
) -> anyhow::Result<AstDfgRef> {
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
32 changes: 21 additions & 11 deletions crates/sparrow-compiler/src/ast_to_dfg/ast_dfg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::{Arc, Mutex};

use egg::Id;
use sparrow_plan::GroupId;
Expand All @@ -11,22 +10,22 @@ use crate::time_domain::TimeDomain;
///
/// We may have multiple references to the same AstDfg node, so this allows us
/// to clone shallow references rather than deeply copy.
pub type AstDfgRef = Rc<AstDfg>;
pub type AstDfgRef = Arc<AstDfg>;

/// Various information produced for each AST node during the conversion to the
/// DFG.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct AstDfg {
/// Reference to the step containing the values of the AST node.
///
/// Wrapped in a `RefCell` to allow mutability during
/// pruning/simplification.
pub(crate) value: RefCell<Id>,
pub(crate) value: Mutex<Id>,
/// Reference to the step containing the "is_new" bits of the AST node.
///
/// Wrapped in a `RefCell` to allow mutability during
/// pruning/simplification.
pub(crate) is_new: RefCell<Id>,
pub(crate) is_new: Mutex<Id>,
/// Type of `value` produced.
value_type: FenlType,
/// Which entity grouping the node is associated with (if any).
Expand Down Expand Up @@ -77,8 +76,8 @@ impl AstDfg {
};

AstDfg {
value: RefCell::new(value),
is_new: RefCell::new(is_new),
value: Mutex::new(value),
is_new: Mutex::new(is_new),
value_type,
grouping,
time_domain,
Expand All @@ -87,12 +86,23 @@ impl AstDfg {
}
}

pub(crate) fn value(&self) -> Id {
*self.value.borrow()
/// Compares two `AstDfg` nodes field by field.
///
/// Temporary equality check used by the Python builder; see the locking
/// caveat below.
pub fn equivalent(&self, other: &AstDfg) -> bool {
    // This isn't quite correct -- we should lock everything and then compare.
    // But, this is a temporary hack for the Python builder.
    // (`value()` and `is_new()` each take and release their own lock, so the
    // comparison is not a single atomic snapshot.)
    self.value() == other.value()
        && self.is_new() == other.is_new()
        && self.value_type == other.value_type
        && self.grouping == other.grouping
        && self.time_domain == other.time_domain
        && self.location == other.location
}

/// Returns the id of the step containing this node's value.
///
/// Briefly locks the `value` mutex and copies the `Id` out; panics if the
/// mutex is poisoned (the `unwrap`).
pub fn value(&self) -> Id {
    *self.value.lock().unwrap()
}

pub(crate) fn is_new(&self) -> Id {
*self.is_new.borrow()
*self.is_new.lock().unwrap()
}

pub fn value_type(&self) -> &FenlType {
Expand Down
8 changes: 4 additions & 4 deletions crates/sparrow-compiler/src/ast_to_dfg/record_ops_to_dfg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::rc::Rc;
use std::sync::Arc;

use anyhow::Context;
use arrow::datatypes::{DataType, Field, FieldRef};
Expand Down Expand Up @@ -137,7 +137,7 @@ pub(super) fn record_to_dfg(
// Create the value after the fields since this takes ownership of the names.
let value = dfg.add_expression(Expression::Inst(InstKind::Record), instruction_args)?;

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -250,7 +250,7 @@ pub(super) fn extend_record_to_dfg(
TimeDomain::error()
});

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -377,7 +377,7 @@ pub(super) fn select_remove_fields(
let value = dfg.add_expression(Expression::Inst(InstKind::Record), record_args)?;
let value_type = FenlType::Concrete(DataType::Struct(result_fields.into()));

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
record.is_new(),
value_type,
Expand Down
3 changes: 1 addition & 2 deletions crates/sparrow-compiler/src/data_context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::Context;
Expand Down Expand Up @@ -423,7 +422,7 @@ impl TableInfo {
let value_type = DataType::Struct(self.schema().fields().clone());
let value_type = FenlType::Concrete(value_type);

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
12 changes: 6 additions & 6 deletions crates/sparrow-compiler/src/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod simplification;
mod step_kind;
mod useless_transforms;

use std::rc::Rc;
use std::sync::Arc;

pub(crate) use analysis::*;
use anyhow::Context;
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Default for Dfg {
// Preemptively create a single error node, allowing for shallow
// clones of the reference.
let error_id = graph.add(DfgLang::new(StepKind::Error, smallvec![]));
let error_node = Rc::new(AstDfg::new(
let error_node = Arc::new(AstDfg::new(
error_id,
error_id,
FenlType::Error,
Expand Down Expand Up @@ -487,12 +487,12 @@ impl Dfg {
});
self.env.foreach_value(|node| {
let old_value = old_graph.find(node.value());
node.value
.replace_with(|_| mapping.get(&old_value).copied().unwrap_or(new_error));
let new_value = mapping.get(&old_value).copied().unwrap_or(new_error);
*node.value.lock().unwrap() = new_value;

let old_is_new = old_graph.find(node.is_new());
node.is_new
.replace_with(|_| mapping.get(&old_is_new).copied().unwrap_or(new_error));
let new_is_new = mapping.get(&old_is_new).copied().unwrap_or(new_error);
*node.is_new.lock().unwrap() = new_is_new;
});
self.graph = new_graph;
Ok(new_output)
Expand Down
10 changes: 9 additions & 1 deletion crates/sparrow-compiler/src/diagnostics/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl<'a> std::fmt::Debug for DiagnosticCollector<'a> {
pub struct CollectedDiagnostic {
code: DiagnosticCode,
formatted: String,
pub message: String,
}

impl CollectedDiagnostic {
Expand Down Expand Up @@ -119,9 +120,11 @@ impl<'a> DiagnosticCollector<'a> {
self.collected.push(CollectedDiagnostic {
code: DiagnosticCode::FailedToReport,
formatted: "Failed to report diagnostic".to_owned(),
message: "Failed to report diagnostic".to_owned(),
});
return;
};
let message = diagnostic.message.clone();
let formatted = match String::from_utf8(buffer.into_inner()) {
Ok(formatted) => formatted,
Err(err) => {
Expand All @@ -132,12 +135,17 @@ impl<'a> DiagnosticCollector<'a> {
self.collected.push(CollectedDiagnostic {
code: DiagnosticCode::FailedToReport,
formatted: "Failed to report diagnostic".to_owned(),
message,
});
return;
}
};

let diagnostic = CollectedDiagnostic { code, formatted };
let diagnostic = CollectedDiagnostic {
code,
formatted,
message,
};

match code.severity() {
Severity::Bug | Severity::Error => {
Expand Down
6 changes: 3 additions & 3 deletions crates/sparrow-compiler/src/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) mod resolve_arguments;
mod slice_analysis;

use std::collections::BTreeSet;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::anyhow;
use arrow::datatypes::{DataType, TimeUnit};
Expand Down Expand Up @@ -352,7 +352,7 @@ fn create_changed_since_time_node(dfg: &mut Dfg) -> anyhow::Result<AstDfgRef> {
)?;
let value_type = FenlType::Concrete(DataType::Timestamp(TimeUnit::Nanosecond, None));
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand All @@ -375,7 +375,7 @@ fn create_final_at_time_time_node(dfg: &mut Dfg) -> anyhow::Result<AstDfgRef> {
)?;
let value_type = FenlType::Concrete(DataType::Timestamp(TimeUnit::Nanosecond, None));
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
4 changes: 2 additions & 2 deletions crates/sparrow-compiler/src/functions/function.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::rc::Rc;
use std::str::FromStr;
use std::sync::Arc;

use egg::{Subst, Var};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -224,7 +224,7 @@ impl Function {
self.time_domain_check
.check_args(location, diagnostics, args, data_context)?;

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
22 changes: 14 additions & 8 deletions crates/sparrow-compiler/src/query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use uuid::Uuid;
/// Kaskada query builder.
#[derive(Default)]
pub struct QueryBuilder {
data_context: DataContext,
pub data_context: DataContext,
dfg: Dfg,
}

Expand All @@ -28,8 +28,8 @@ pub enum Error {
Invalid,
#[display(fmt = "no function named '{name}': nearest matches are {nearest:?}")]
NoSuchFunction { name: String, nearest: Vec<String> },
#[display(fmt = "errors during construction")]
Errors,
#[display(fmt = "{}", "_0.iter().join(\"\n\")")]
Errors(Vec<String>),
}

impl error_stack::Context for Error {}
Expand Down Expand Up @@ -106,7 +106,13 @@ impl QueryBuilder {
.change_context(Error::Invalid)?;

if diagnostics.num_errors() > 0 {
Err(Error::Errors)?
let errors = diagnostics
.finish()
.into_iter()
.filter(|diagnostic| diagnostic.is_error())
.map(|diagnostic| diagnostic.message)
.collect();
Err(Error::Errors(errors))?
} else {
Ok(result)
}
Expand All @@ -128,12 +134,12 @@ impl QueryBuilder {
pub fn add_expr(
&mut self,
function: &str,
args: &[AstDfgRef],
args: Vec<AstDfgRef>,
) -> error_stack::Result<AstDfgRef, Error> {
let (op, args) = match function {
"fieldref" => {
assert_eq!(args.len(), 2);
let (base, name) = args.iter().cloned().collect_tuple().unwrap();
let (base, name) = args.into_iter().collect_tuple().unwrap();

let name = self.dfg.string_literal(name.value()).expect("literal name");

Expand Down Expand Up @@ -182,7 +188,7 @@ impl QueryBuilder {
let has_vararg = args.len() > function.signature().arg_names().len();
let args = Resolved::new(
function.signature().arg_names().into(),
args.iter().cloned().map(Located::builder).collect(),
args.into_iter().map(Located::builder).collect(),
has_vararg,
);
(op, args)
Expand Down Expand Up @@ -224,7 +230,7 @@ mod tests {
.add_literal(Literal::StringLiteral("a".to_owned()))
.unwrap();
let field_ref = query_builder
.add_expr("fieldref", &[table, field_name])
.add_expr("fieldref", vec![table, field_name])
.unwrap();

assert_eq!(
Expand Down
Loading

0 comments on commit 4483e13

Please sign in to comment.