Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Connect Python builder to Rust #561

Merged
merged 2 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod window_args;

#[cfg(test)]
mod tests;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::{anyhow, Context};
use arrow::datatypes::{DataType, FieldRef};
Expand Down Expand Up @@ -135,7 +135,7 @@ pub(super) fn add_to_dfg(

if CastEvaluator::is_supported_fenl(from_type, to_type) {
if let FenlType::Concrete(to_type) = to_type.inner() {
return Ok(Rc::new(AstDfg::new(
return Ok(Arc::new(AstDfg::new(
dfg.add_expression(
Expression::Inst(InstKind::Cast(to_type.clone())),
smallvec![input.value()],
Expand Down Expand Up @@ -237,7 +237,7 @@ pub(super) fn add_to_dfg(
)?;
let is_new = base.is_new();
let value_type = field_type.clone().into();
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -275,7 +275,7 @@ pub(super) fn add_to_dfg(
let agg_input_op = dfg.operation(agg_input.value());
let tick_input = smallvec![agg_input_op];
let tick_node = dfg.add_operation(Operation::Tick(behavior), tick_input)?;
let tick_node = Rc::new(AstDfg::new(
let tick_node = Arc::new(AstDfg::new(
tick_node,
tick_node,
FenlType::Concrete(DataType::Boolean),
Expand Down Expand Up @@ -508,7 +508,7 @@ pub(super) fn add_to_dfg(
// Add cast operations as necessary
let args: Vec<_> = izip!(arguments, instantiated_types)
.map(|(arg, expected_type)| -> anyhow::Result<_> {
let ast_dfg = Rc::new(AstDfg::new(
let ast_dfg = Arc::new(AstDfg::new(
cast_if_needed(dfg, arg.value(), arg.value_type(), &expected_type)?,
arg.is_new(),
expected_type,
Expand Down Expand Up @@ -766,7 +766,7 @@ fn add_literal(
location: Location,
) -> anyhow::Result<AstDfgRef> {
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
32 changes: 21 additions & 11 deletions crates/sparrow-compiler/src/ast_to_dfg/ast_dfg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::{Arc, Mutex};

use egg::Id;
use sparrow_plan::GroupId;
Expand All @@ -11,22 +10,22 @@ use crate::time_domain::TimeDomain;
///
/// We may have multiple references to the same AstDfg node, so this allows us
to clone shallow references rather than deeply copying the node.
pub type AstDfgRef = Rc<AstDfg>;
pub type AstDfgRef = Arc<AstDfg>;

/// Various information produced for each AST node during the conversion to the
/// DFG.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct AstDfg {
/// Reference to the step containing the values of the AST node.
///
/// Wrapped in a `Mutex` to allow mutability during
/// pruning/simplification.
pub(crate) value: RefCell<Id>,
pub(crate) value: Mutex<Id>,
/// Reference to the step containing the "is_new" bits of the AST node.
///
/// Wrapped in a `Mutex` to allow mutability during
/// pruning/simplification.
pub(crate) is_new: RefCell<Id>,
pub(crate) is_new: Mutex<Id>,
/// Type of `value` produced.
value_type: FenlType,
/// Which entity grouping the node is associated with (if any).
Expand Down Expand Up @@ -77,8 +76,8 @@ impl AstDfg {
};

AstDfg {
value: RefCell::new(value),
is_new: RefCell::new(is_new),
value: Mutex::new(value),
is_new: Mutex::new(is_new),
value_type,
grouping,
time_domain,
Expand All @@ -87,12 +86,23 @@ impl AstDfg {
}
}

pub(crate) fn value(&self) -> Id {
*self.value.borrow()
pub fn equivalent(&self, other: &AstDfg) -> bool {
// This is not quite correct -- we should lock everything and then compare.
// But, this is a temporary hack for the Python builder.
self.value() == other.value()
&& self.is_new() == other.is_new()
&& self.value_type == other.value_type
&& self.grouping == other.grouping
&& self.time_domain == other.time_domain
&& self.location == other.location
}

pub fn value(&self) -> Id {
*self.value.lock().unwrap()
}

pub(crate) fn is_new(&self) -> Id {
*self.is_new.borrow()
*self.is_new.lock().unwrap()
}

pub fn value_type(&self) -> &FenlType {
Expand Down
8 changes: 4 additions & 4 deletions crates/sparrow-compiler/src/ast_to_dfg/record_ops_to_dfg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::rc::Rc;
use std::sync::Arc;

use anyhow::Context;
use arrow::datatypes::{DataType, Field, FieldRef};
Expand Down Expand Up @@ -137,7 +137,7 @@ pub(super) fn record_to_dfg(
// Create the value after the fields since this takes ownership of the names.
let value = dfg.add_expression(Expression::Inst(InstKind::Record), instruction_args)?;

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -250,7 +250,7 @@ pub(super) fn extend_record_to_dfg(
TimeDomain::error()
});

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down Expand Up @@ -377,7 +377,7 @@ pub(super) fn select_remove_fields(
let value = dfg.add_expression(Expression::Inst(InstKind::Record), record_args)?;
let value_type = FenlType::Concrete(DataType::Struct(result_fields.into()));

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
record.is_new(),
value_type,
Expand Down
3 changes: 1 addition & 2 deletions crates/sparrow-compiler/src/data_context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::Context;
Expand Down Expand Up @@ -423,7 +422,7 @@ impl TableInfo {
let value_type = DataType::Struct(self.schema().fields().clone());
let value_type = FenlType::Concrete(value_type);

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
12 changes: 6 additions & 6 deletions crates/sparrow-compiler/src/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod simplification;
mod step_kind;
mod useless_transforms;

use std::rc::Rc;
use std::sync::Arc;

pub(crate) use analysis::*;
use anyhow::Context;
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Default for Dfg {
// Preemptively create a single error node, allowing for shallow
// clones of the reference.
let error_id = graph.add(DfgLang::new(StepKind::Error, smallvec![]));
let error_node = Rc::new(AstDfg::new(
let error_node = Arc::new(AstDfg::new(
error_id,
error_id,
FenlType::Error,
Expand Down Expand Up @@ -487,12 +487,12 @@ impl Dfg {
});
self.env.foreach_value(|node| {
let old_value = old_graph.find(node.value());
node.value
.replace_with(|_| mapping.get(&old_value).copied().unwrap_or(new_error));
let new_value = mapping.get(&old_value).copied().unwrap_or(new_error);
*node.value.lock().unwrap() = new_value;

let old_is_new = old_graph.find(node.is_new());
node.is_new
.replace_with(|_| mapping.get(&old_is_new).copied().unwrap_or(new_error));
let new_is_new = mapping.get(&old_is_new).copied().unwrap_or(new_error);
*node.is_new.lock().unwrap() = new_is_new;
});
self.graph = new_graph;
Ok(new_output)
Expand Down
10 changes: 9 additions & 1 deletion crates/sparrow-compiler/src/diagnostics/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl<'a> std::fmt::Debug for DiagnosticCollector<'a> {
pub struct CollectedDiagnostic {
code: DiagnosticCode,
formatted: String,
pub message: String,
}

impl CollectedDiagnostic {
Expand Down Expand Up @@ -119,9 +120,11 @@ impl<'a> DiagnosticCollector<'a> {
self.collected.push(CollectedDiagnostic {
code: DiagnosticCode::FailedToReport,
formatted: "Failed to report diagnostic".to_owned(),
message: "Failed to report diagnostic".to_owned(),
});
return;
};
let message = diagnostic.message.clone();
let formatted = match String::from_utf8(buffer.into_inner()) {
Ok(formatted) => formatted,
Err(err) => {
Expand All @@ -132,12 +135,17 @@ impl<'a> DiagnosticCollector<'a> {
self.collected.push(CollectedDiagnostic {
code: DiagnosticCode::FailedToReport,
formatted: "Failed to report diagnostic".to_owned(),
message,
});
return;
}
};

let diagnostic = CollectedDiagnostic { code, formatted };
let diagnostic = CollectedDiagnostic {
code,
formatted,
message,
};

match code.severity() {
Severity::Bug | Severity::Error => {
Expand Down
6 changes: 3 additions & 3 deletions crates/sparrow-compiler/src/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) mod resolve_arguments;
mod slice_analysis;

use std::collections::BTreeSet;
use std::rc::Rc;
use std::sync::Arc;

use anyhow::anyhow;
use arrow::datatypes::{DataType, TimeUnit};
Expand Down Expand Up @@ -352,7 +352,7 @@ fn create_changed_since_time_node(dfg: &mut Dfg) -> anyhow::Result<AstDfgRef> {
)?;
let value_type = FenlType::Concrete(DataType::Timestamp(TimeUnit::Nanosecond, None));
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand All @@ -375,7 +375,7 @@ fn create_final_at_time_time_node(dfg: &mut Dfg) -> anyhow::Result<AstDfgRef> {
)?;
let value_type = FenlType::Concrete(DataType::Timestamp(TimeUnit::Nanosecond, None));
let is_new = dfg.add_literal(false)?;
Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
4 changes: 2 additions & 2 deletions crates/sparrow-compiler/src/functions/function.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::rc::Rc;
use std::str::FromStr;
use std::sync::Arc;

use egg::{Subst, Var};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -224,7 +224,7 @@ impl Function {
self.time_domain_check
.check_args(location, diagnostics, args, data_context)?;

Ok(Rc::new(AstDfg::new(
Ok(Arc::new(AstDfg::new(
value,
is_new,
value_type,
Expand Down
22 changes: 14 additions & 8 deletions crates/sparrow-compiler/src/query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use uuid::Uuid;
/// Kaskada query builder.
#[derive(Default)]
pub struct QueryBuilder {
data_context: DataContext,
pub data_context: DataContext,
dfg: Dfg,
}

Expand All @@ -28,8 +28,8 @@ pub enum Error {
Invalid,
#[display(fmt = "no function named '{name}': nearest matches are {nearest:?}")]
NoSuchFunction { name: String, nearest: Vec<String> },
#[display(fmt = "errors during construction")]
Errors,
#[display(fmt = "{}", "_0.iter().join(\"\n\")")]
Errors(Vec<String>),
}

impl error_stack::Context for Error {}
Expand Down Expand Up @@ -106,7 +106,13 @@ impl QueryBuilder {
.change_context(Error::Invalid)?;

if diagnostics.num_errors() > 0 {
Err(Error::Errors)?
let errors = diagnostics
.finish()
.into_iter()
.filter(|diagnostic| diagnostic.is_error())
.map(|diagnostic| diagnostic.message)
.collect();
Err(Error::Errors(errors))?
} else {
Ok(result)
}
Expand All @@ -128,12 +134,12 @@ impl QueryBuilder {
pub fn add_expr(
&mut self,
function: &str,
args: &[AstDfgRef],
args: Vec<AstDfgRef>,
) -> error_stack::Result<AstDfgRef, Error> {
let (op, args) = match function {
"fieldref" => {
assert_eq!(args.len(), 2);
let (base, name) = args.iter().cloned().collect_tuple().unwrap();
let (base, name) = args.into_iter().collect_tuple().unwrap();

let name = self.dfg.string_literal(name.value()).expect("literal name");

Expand Down Expand Up @@ -182,7 +188,7 @@ impl QueryBuilder {
let has_vararg = args.len() > function.signature().arg_names().len();
let args = Resolved::new(
function.signature().arg_names().into(),
args.iter().cloned().map(Located::builder).collect(),
args.into_iter().map(Located::builder).collect(),
has_vararg,
);
(op, args)
Expand Down Expand Up @@ -224,7 +230,7 @@ mod tests {
.add_literal(Literal::StringLiteral("a".to_owned()))
.unwrap();
let field_ref = query_builder
.add_expr("fieldref", &[table, field_name])
.add_expr("fieldref", vec![table, field_name])
.unwrap();

assert_eq!(
Expand Down
Loading
Loading