From 89aecd15e3f7d2fd900611c042a3b2aee8392f71 Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Mon, 3 Jun 2024 18:45:32 +0200 Subject: [PATCH] Split `ColumnOp` into one with row indices and one with `FieldName` & other enabled changes (#1207) * 1. Split ColumnOp into ColumnOp & FieldOp, former storing ColId 2. Shrink SqlAst to 80 bytes, so it can be passed in registers 3. Store end-result Header in IndexSemiJoin 4. Remove operational use of Header in ColumnOp & build_query 5. Simplify RowRef::{get, project, project_owned} * Make parts of build_query actually infallible. 1. Make IndexSemiJoin::filter infallible. 2. Make ColumnOp::compare and friends infallible. 3. Make RowRef::{get, project, project_owned} infallible. * Make RelOps::next itself infallible * 1. with_select{_cmp}: ensure type safety so query exec cannot panic 2. Document RelValue::{get, read_or_take_column, project_owned} 3. Refactor optimize_select 4. Ensure in optimize_select that conditions are merged with preceding selects * remove RelOps::{head, row_count}; head is redundant & row_count is useless * remove Relation trait; it does not carry its weight * make build_query infallible * simplify IndexSemiJoin, make it slightly less branchy * simplify try_index_join * split IndexSemiJoin into Left & Right parts * move get_field_pos to test code * move test version of build_query to test code --- crates/bench/benches/subscription.rs | 12 +- crates/core/src/db/cursor.rs | 4 +- .../locking_tx_datastore/state_view.rs | 6 +- crates/core/src/db/mod.rs | 1 - crates/core/src/error.rs | 2 + crates/core/src/host/instance_env.rs | 30 +- crates/core/src/sql/ast.rs | 68 +- crates/core/src/sql/compiler.rs | 89 +-- crates/core/src/sql/execute.rs | 13 +- .../core/src/subscription/execution_unit.rs | 79 +- .../subscription/module_subscription_actor.rs | 2 +- .../module_subscription_manager.rs | 11 +- crates/core/src/subscription/query.rs | 29 +- crates/core/src/subscription/subscription.rs | 103 ++- 
crates/core/src/vm.rs | 357 +++++---- crates/data-structures/src/map.rs | 4 + crates/sats/src/db/error.rs | 7 +- crates/sats/src/relation.rs | 100 +-- crates/vm/src/errors.rs | 2 - crates/vm/src/eval.rs | 196 ++--- crates/vm/src/expr.rs | 736 ++++++++++-------- crates/vm/src/iterators.rs | 23 +- crates/vm/src/program.rs | 43 - crates/vm/src/rel_ops.rs | 145 +--- crates/vm/src/relation.rs | 76 +- 25 files changed, 1017 insertions(+), 1121 deletions(-) diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index 4e005c75df..c9ce3fd36e 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -104,7 +104,7 @@ fn eval(c: &mut Criterion) { let query = compile_read_only_query(&raw.db, &tx, sql).unwrap(); let query: ExecutionSet = query.into(); let ctx = &ExecutionContext::subscribe(raw.db.address(), SlowQueryConfig::default()); - b.iter(|| drop(black_box(query.eval(ctx, Protocol::Binary, &raw.db, &tx).unwrap()))) + b.iter(|| drop(black_box(query.eval(ctx, Protocol::Binary, &raw.db, &tx)))) }); }; @@ -140,10 +140,7 @@ fn eval(c: &mut Criterion) { let query = ExecutionSet::from_iter(query_lhs.into_iter().chain(query_rhs)); let tx = &tx.into(); - b.iter(|| { - let out = query.eval_incr(ctx_incr, &raw.db, tx, &update).unwrap(); - black_box(out); - }) + b.iter(|| drop(black_box(query.eval_incr(ctx_incr, &raw.db, tx, &update)))) }); // To profile this benchmark for 30s @@ -161,10 +158,7 @@ fn eval(c: &mut Criterion) { let query: ExecutionSet = query.into(); let tx = &tx.into(); - b.iter(|| { - let out = query.eval_incr(ctx_incr, &raw.db, tx, &update).unwrap(); - black_box(out); - }) + b.iter(|| drop(black_box(query.eval_incr(ctx_incr, &raw.db, tx, &update)))); }); // To profile this benchmark for 30s diff --git a/crates/core/src/db/cursor.rs b/crates/core/src/db/cursor.rs index e7732060df..deec7af762 100644 --- a/crates/core/src/db/cursor.rs +++ b/crates/core/src/db/cursor.rs @@ -6,12 +6,12 @@ use 
spacetimedb_sats::AlgebraicValue; /// Common wrapper for relational iterators that work like cursors. pub struct TableCursor<'a> { - pub table: DbTable, + pub table: &'a DbTable, pub iter: Iter<'a>, } impl<'a> TableCursor<'a> { - pub fn new(table: DbTable, iter: Iter<'a>) -> Result { + pub fn new(table: &'a DbTable, iter: Iter<'a>) -> Result { Ok(Self { table, iter }) } } diff --git a/crates/core/src/db/datastore/locking_tx_datastore/state_view.rs b/crates/core/src/db/datastore/locking_tx_datastore/state_view.rs index 0658d9e32e..a4d75e1be3 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/state_view.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/state_view.rs @@ -286,10 +286,10 @@ impl<'a> Iterator for Iter<'a> { // // As a result, in MVCC, this branch will need to check if the `row_ref` // also exists in the `tx_state.insert_tables` and ensure it is yielded only once. - if !self + if self .tx_state - .map(|tx_state| tx_state.is_deleted(table_id, row_ref.pointer())) - .unwrap_or(false) + .filter(|tx_state| tx_state.is_deleted(table_id, row_ref.pointer())) + .is_none() { // There either are no state changes for the current tx (`None`), // or there are, but `row_id` specifically has not been changed. 
diff --git a/crates/core/src/db/mod.rs b/crates/core/src/db/mod.rs index 02addbc311..0a91361984 100644 --- a/crates/core/src/db/mod.rs +++ b/crates/core/src/db/mod.rs @@ -1,4 +1,3 @@ -pub mod cursor; pub mod datastore; pub mod db_metrics; pub mod relational_db; diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 68c9925750..03d1447e74 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -320,6 +320,8 @@ pub enum NodesError { SystemName(Box), #[error("internal db error: {0}")] Internal(#[source] Box), + #[error(transparent)] + BadQuery(#[from] RelationError), #[error("invalid index type: {0}")] BadIndexType(u8), } diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index 16822ebe4f..0373294a3c 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -19,9 +19,9 @@ use spacetimedb_lib::operator::OpQuery; use spacetimedb_lib::ProductValue; use spacetimedb_primitives::{ColId, ColListBuilder, TableId}; use spacetimedb_sats::db::def::{IndexDef, IndexType}; -use spacetimedb_sats::relation::{FieldExpr, FieldName}; +use spacetimedb_sats::relation::FieldName; use spacetimedb_sats::Typespace; -use spacetimedb_vm::expr::{ColumnOp, NoInMemUsed, QueryExpr}; +use spacetimedb_vm::expr::{FieldExpr, FieldOp, NoInMemUsed, QueryExpr}; #[derive(Clone)] pub struct InstanceEnv { @@ -320,27 +320,27 @@ impl InstanceEnv { ) -> Result>, NodesError> { use spacetimedb_lib::filter; - fn filter_to_column_op(table_id: TableId, filter: filter::Expr) -> ColumnOp { + fn filter_to_column_op(table_id: TableId, filter: filter::Expr) -> FieldOp { match filter { filter::Expr::Cmp(filter::Cmp { op, args: CmpArgs { lhs_field, rhs }, - }) => ColumnOp::Cmp { + }) => FieldOp::Cmp { op: OpQuery::Cmp(op), - lhs: Box::new(ColumnOp::Field(FieldExpr::Name(FieldName::new( + lhs: Box::new(FieldOp::Field(FieldExpr::Name(FieldName::new( table_id, lhs_field.into(), )))), - rhs: 
Box::new(ColumnOp::Field(match rhs { + rhs: Box::new(FieldOp::Field(match rhs { filter::Rhs::Field(rhs_field) => FieldExpr::Name(FieldName::new(table_id, rhs_field.into())), filter::Rhs::Value(rhs_value) => FieldExpr::Value(rhs_value), })), }, - filter::Expr::Logic(filter::Logic { lhs, op, rhs }) => ColumnOp::Cmp { - op: OpQuery::Logic(op), - lhs: Box::new(filter_to_column_op(table_id, *lhs)), - rhs: Box::new(filter_to_column_op(table_id, *rhs)), - }, + filter::Expr::Logic(filter::Logic { lhs, op, rhs }) => FieldOp::new( + OpQuery::Logic(op), + filter_to_column_op(table_id, *lhs), + filter_to_column_op(table_id, *rhs), + ), filter::Expr::Unary(_) => todo!("unary operations are not yet supported"), } } @@ -363,7 +363,7 @@ impl InstanceEnv { // TODO(Centril): consider caching from `filter: &[u8] -> query: QueryExpr`. let query = QueryExpr::new(schema.as_ref()) - .with_select(filter_to_column_op(table_id, filter)) + .with_select(filter_to_column_op(table_id, filter))? .optimize(&|table_id, table_name| stdb.row_count(table_id, table_name)); // TODO(Centril): Conditionally dump the `query` to a file and compare against integration test. @@ -371,11 +371,11 @@ impl InstanceEnv { let tx: TxMode = tx.into(); // SQL queries can never reference `MemTable`s, so pass in an empty set. - let mut query = build_query(ctx, stdb, &tx, &query, &mut NoInMemUsed)?; + let mut query = build_query(ctx, stdb, &tx, &query, &mut NoInMemUsed); // write all rows and flush at row boundaries. 
- let query_iter = std::iter::from_fn(|| query.next().transpose()); - let chunks = itertools::process_results(query_iter, |it| ChunkedWriter::collect_iter(it))?; + let query_iter = std::iter::from_fn(|| query.next()); + let chunks = ChunkedWriter::collect_iter(query_iter); Ok(chunks) } } diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index f1bf583922..b66ff1ecff 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -1,14 +1,14 @@ use crate::config::ReadConfigOption; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::{DBError, PlanError}; -use spacetimedb_data_structures::map::HashMap; -use spacetimedb_primitives::{ColList, ConstraintKind, Constraints}; +use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; +use spacetimedb_primitives::{ColId, ColList, ConstraintKind, Constraints}; use spacetimedb_sats::db::def::{ColumnDef, ConstraintDef, TableDef, TableSchema}; use spacetimedb_sats::db::error::RelationError; -use spacetimedb_sats::relation::{FieldExpr, FieldName}; +use spacetimedb_sats::relation::{ColExpr, FieldName}; use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; use spacetimedb_vm::errors::ErrorVm; -use spacetimedb_vm::expr::{ColumnOp, DbType, Expr}; +use spacetimedb_vm::expr::{DbType, Expr, FieldExpr, FieldOp}; use spacetimedb_vm::operator::{OpCmp, OpLogic, OpQuery}; use spacetimedb_vm::ops::parse::{parse, parse_simple_enum}; use sqlparser::ast::{ @@ -103,12 +103,12 @@ pub enum Column { /// The list of expressions for `SELECT expr1, expr2...` determining what data to extract. 
#[derive(Debug, Clone)] pub struct Selection { - pub(crate) clause: ColumnOp, + pub(crate) clause: FieldOp, } impl Selection { - pub fn with_cmp(op: OpQuery, lhs: ColumnOp, rhs: ColumnOp) -> Self { - let cmp = ColumnOp::new(op, lhs, rhs); + pub fn with_cmp(op: OpQuery, lhs: FieldOp, rhs: FieldOp) -> Self { + let cmp = FieldOp::new(op, lhs, rhs); Selection { clause: cmp } } } @@ -257,17 +257,17 @@ pub fn find_field<'a>( pub enum SqlAst { Select { from: From, - project: Vec, + project: Box<[Column]>, selection: Option, }, Insert { table: Arc, - columns: Vec, - values: Vec>, + columns: Box<[ColId]>, + values: Box<[Box<[ColExpr]>]>, }, Update { table: Arc, - assignments: HashMap, + assignments: IntMap, selection: Option, }, Delete { @@ -275,7 +275,7 @@ pub enum SqlAst { selection: Option, }, CreateTable { - table: TableDef, + table: Box, }, Drop { name: String, @@ -350,8 +350,8 @@ fn compile_expr_value<'a>( tables: impl Clone + Iterator, field: Option<&'a AlgebraicType>, of: SqlExpr, -) -> Result { - Ok(ColumnOp::Field(match of { +) -> Result { + Ok(FieldOp::Field(match of { SqlExpr::Identifier(name) => FieldExpr::Name(find_field(tables, &name.value)?.0), SqlExpr::CompoundIdentifier(ident) => { let col_name = compound_ident(&ident); @@ -373,7 +373,7 @@ fn compile_expr_value<'a>( SqlExpr::BinaryOp { left, op, right } => { let (op, lhs, rhs) = compile_bin_op(tables, op, left, right)?; - return Ok(ColumnOp::new(op, lhs, rhs)); + return Ok(FieldOp::new(op, lhs, rhs)); } SqlExpr::Nested(x) => { return compile_expr_value(tables, field, *x); @@ -388,7 +388,7 @@ fn compile_expr_value<'a>( fn compile_expr_field(table: &From, field: Option<&AlgebraicType>, of: SqlExpr) -> Result { match compile_expr_value(table.iter_tables(), field, of)? 
{ - ColumnOp::Field(field) => Ok(field), + FieldOp::Field(field) => Ok(field), x => Err(PlanError::Unsupported { feature: format!("Complex expression {x} on insert..."), }), @@ -422,7 +422,7 @@ fn compile_bin_op<'a>( op: BinaryOperator, lhs: Box, rhs: Box, -) -> Result<(OpQuery, ColumnOp, ColumnOp), PlanError> { +) -> Result<(OpQuery, FieldOp, FieldOp), PlanError> { let op: OpQuery = match op { BinaryOperator::Gt => OpCmp::Gt.into(), BinaryOperator::Lt => OpCmp::Lt.into(), @@ -532,8 +532,8 @@ fn compile_from(db: &RelationalDB, tx: &T, from: &[TableWith let tables = base.iter_tables().chain([&*join]); let expr = compile_expr_value(tables, None, x.clone())?; match expr { - ColumnOp::Field(_) => {} - ColumnOp::Cmp { op, lhs, rhs } => { + FieldOp::Field(_) => {} + FieldOp::Cmp { op, lhs, rhs } => { let op = match op { OpQuery::Cmp(op) => op, OpQuery::Logic(op) => { @@ -543,7 +543,7 @@ fn compile_from(db: &RelationalDB, tx: &T, from: &[TableWith } }; let (lhs, rhs) = match (*lhs, *rhs) { - (ColumnOp::Field(FieldExpr::Name(lhs)), ColumnOp::Field(FieldExpr::Name(rhs))) => { + (FieldOp::Field(FieldExpr::Name(lhs)), FieldOp::Field(FieldExpr::Name(rhs))) => { (lhs, rhs) } (lhs, rhs) => { @@ -597,7 +597,7 @@ fn compile_select_item(from: &From, select_item: SelectItem) -> Result { let value = compile_expr_value(from.iter_tables(), None, expr)?; match value { - ColumnOp::Field(value) => match value { + FieldOp::Field(value) => match value { FieldExpr::Name(_) => Err(PlanError::Unsupported { feature: "Should not be an identifier in Expr::Value".to_string(), }), @@ -626,11 +626,13 @@ fn compile_select_item(from: &From, select_item: SelectItem) -> Result(db: &RelationalDB, tx: &T, select: Select) -> Result { let from = compile_from(db, tx, &select.from)?; + // SELECT ... 
let mut project = Vec::with_capacity(select.projection.len()); for select_item in select.projection { project.push(compile_select_item(&from, select_item)?); } + let project = project.into(); let selection = compile_where(&from, select.selection)?; @@ -711,21 +713,21 @@ fn compile_insert( .map(|x| { table .find_field(&format!("{}.{}", &table.root.table_name, x)) - .map(|(f, _)| f) + .map(|(f, _)| f.col) }) - .collect::, _>>()?; + .collect::, _>>()?; let mut values = Vec::with_capacity(data.rows.len()); - for x in &data.rows { let mut row = Vec::with_capacity(x.len()); for (pos, v) in x.iter().enumerate() { let field_ty = table.root.get_column(pos).map(|col| &col.col_type); - row.push(compile_expr_field(&table, field_ty, v.clone())?); + row.push(compile_expr_field(&table, field_ty, v.clone())?.strip_table()); } - - values.push(row); + values.push(row.into()); } + let values = values.into(); + Ok(SqlAst::Insert { table: table.root, columns, @@ -744,19 +746,19 @@ fn compile_update( let table = From::new(tx.find_table(db, table)?); let selection = compile_where(&table, selection)?; - let mut x = HashMap::with_capacity(assignments.len()); - + let mut assigns = IntMap::with_capacity(assignments.len()); for col in assignments { let name: String = col.id.iter().map(|x| x.to_string()).collect(); let (field_name, field_ty) = table.find_field(&name)?; + let col_id = field_name.col; - let value = compile_expr_field(&table, Some(field_ty), col.value)?; - x.insert(field_name, value); + let value = compile_expr_field(&table, Some(field_ty), col.value)?.strip_table(); + assigns.insert(col_id, value); } Ok(SqlAst::Update { table: table.root, - assignments: x, + assignments: assigns, selection, }) } @@ -916,7 +918,7 @@ fn compile_create_table(table: Table, cols: Vec) -> Result Result { } /// Compiles a `WHERE ...` clause -fn compile_where(mut q: QueryExpr, filter: Selection) -> QueryExpr { +fn compile_where(mut q: QueryExpr, filter: Selection) -> Result { for op in 
filter.clause.flatten_ands() { - q = q.with_select(op); + q = q.with_select(op)?; } - q + Ok(q) } /// Compiles a `SELECT ...` clause -fn compile_select(table: From, project: Vec, selection: Option) -> Result { +fn compile_select(table: From, project: Box<[Column]>, selection: Option) -> Result { let mut not_found = Vec::with_capacity(project.len()); let mut col_ids = Vec::new(); let mut qualified_wildcards = Vec::new(); //Match columns to their tables... - for select_item in project { + for select_item in Vec::from(project) { match select_item { Column::UnnamedExpr(x) => match expr_for_projection(&table, x) { Ok(field) => col_ids.push(field), @@ -116,24 +117,24 @@ fn compile_select(table: From, project: Vec, selection: Option) -> DbTable { - let mut columns = Vec::with_capacity(field_names.len()); - let cols = field_names - .into_iter() - .filter_map(|col| table.get_column_by_field(col)) +fn compile_columns(table: &TableSchema, cols: &[ColId]) -> DbTable { + let mut columns = Vec::with_capacity(cols.len()); + let cols = cols + .iter() + .filter_map(|col| table.get_column(col.idx())) .map(|col| relation::Column::new(FieldName::new(table.table_id, col.col_pos), col.col_type.clone())); columns.extend(cols); @@ -148,18 +149,18 @@ fn compile_columns(table: &TableSchema, field_names: Vec) -> DbTable } /// Compiles a `INSERT ...` clause -fn compile_insert(table: &TableSchema, columns: Vec, values: Vec>) -> CrudExpr { - let table = compile_columns(table, columns); +fn compile_insert(table: &TableSchema, cols: &[ColId], values: Box<[Box<[ColExpr]>]>) -> CrudExpr { + let table = compile_columns(table, cols); let mut rows = Vec::with_capacity(values.len()); - for x in values { + for x in Vec::from(values) { let mut row = Vec::with_capacity(x.len()); - for v in x { + for v in Vec::from(x) { match v { - FieldExpr::Name(x) => { + ColExpr::Col(x) => { todo!("Deal with idents in insert?: {}", x) } - FieldExpr::Value(x) => { + ColExpr::Value(x) => { row.push(x); } } @@ -171,34 
+172,34 @@ fn compile_insert(table: &TableSchema, columns: Vec, values: Vec, selection: Option) -> CrudExpr { +fn compile_delete(table: Arc, selection: Option) -> Result { let query = QueryExpr::new(&*table); let query = if let Some(filter) = selection { - compile_where(query, filter) + compile_where(query, filter)? } else { query }; - CrudExpr::Delete { query } + Ok(CrudExpr::Delete { query }) } /// Compiles a `UPDATE ...` clause fn compile_update( table: Arc, - assignments: HashMap, + assignments: IntMap, selection: Option, -) -> CrudExpr { +) -> Result { let query = QueryExpr::new(&*table); let delete = if let Some(filter) = selection { - compile_where(query, filter) + compile_where(query, filter)? } else { query }; - CrudExpr::Update { delete, assignments } + Ok(CrudExpr::Update { delete, assignments }) } /// Compiles a `CREATE TABLE ...` clause -fn compile_create_table(table: TableDef) -> CrudExpr { +fn compile_create_table(table: Box) -> CrudExpr { CrudExpr::CreateTable { table } } @@ -210,13 +211,13 @@ fn compile_statement(db: &RelationalDB, statement: SqlAst) -> Result CrudExpr::Query(compile_select(from, project, selection)?), - SqlAst::Insert { table, columns, values } => compile_insert(&table, columns, values), + SqlAst::Insert { table, columns, values } => compile_insert(&table, &columns, values), SqlAst::Update { table, assignments, selection, - } => compile_update(table, assignments, selection), - SqlAst::Delete { table, selection } => compile_delete(table, selection), + } => compile_update(table, assignments, selection)?, + SqlAst::Delete { table, selection } => compile_delete(table, selection)?, SqlAst::CreateTable { table } => compile_create_table(table), SqlAst::Drop { name, kind } => CrudExpr::Drop { name, kind }, SqlAst::SetVar { name, value } => CrudExpr::SetVar { name, value }, @@ -643,14 +644,12 @@ mod tests { panic!("unexpected operator {:#?}", query[0]); }; - let ColumnOp::Field(FieldExpr::Name(FieldName { table, col })) = **lhs else { + 
let ColumnOp::Col(ColExpr::Col(col)) = **lhs else { panic!("unexpected left hand side {:#?}", **lhs); }; - - assert_eq!(table, lhs_id); assert_eq!(col, 0.into()); - let ColumnOp::Field(FieldExpr::Value(AlgebraicValue::U64(3))) = **rhs else { + let ColumnOp::Col(ColExpr::Value(AlgebraicValue::U64(3))) = **rhs else { panic!("unexpected right hand side {:#?}", **rhs); }; @@ -728,14 +727,12 @@ mod tests { panic!("unexpected operator {:#?}", rhs.query[0]); }; - let ColumnOp::Field(FieldExpr::Name(FieldName { table, col })) = **lhs else { + let ColumnOp::Col(ColExpr::Col(col)) = **lhs else { panic!("unexpected left hand side {:#?}", **lhs); }; - - assert_eq!(table, rhs_id); assert_eq!(col, 1.into()); - let ColumnOp::Field(FieldExpr::Value(AlgebraicValue::U64(3))) = **rhs else { + let ColumnOp::Col(ColExpr::Value(AlgebraicValue::U64(3))) = **rhs else { panic!("unexpected right hand side {:#?}", **rhs); }; Ok(()) @@ -888,14 +885,12 @@ mod tests { panic!("unexpected operator {:#?}", rhs[0]); }; - let ColumnOp::Field(FieldExpr::Name(FieldName { table, col })) = **field else { + let ColumnOp::Col(ColExpr::Col(col)) = **field else { panic!("unexpected left hand side {:#?}", field); }; - - assert_eq!(table, rhs_id); assert_eq!(col, 2.into()); - let ColumnOp::Field(FieldExpr::Value(AlgebraicValue::U64(3))) = **value else { + let ColumnOp::Col(ColExpr::Value(AlgebraicValue::U64(3))) = **value else { panic!("unexpected right hand side {:#?}", value); }; Ok(()) @@ -976,14 +971,12 @@ mod tests { panic!("unexpected operator {:#?}", rhs[0]); }; - let ColumnOp::Field(FieldExpr::Name(FieldName { table, col })) = **field else { + let ColumnOp::Col(ColExpr::Col(col)) = **field else { panic!("unexpected left hand side {:#?}", field); }; - - assert_eq!(table, rhs_id); assert_eq!(col, 2.into()); - let ColumnOp::Field(FieldExpr::Value(AlgebraicValue::U64(3))) = **value else { + let ColumnOp::Col(ColExpr::Value(AlgebraicValue::U64(3))) = **value else { panic!("unexpected right hand side 
{:#?}", value); }; Ok(()) diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 3591df8ab7..24031fdb30 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -132,6 +132,7 @@ pub(crate) mod tests { use crate::db::relational_db::tests_utils::TestDB; use crate::vm::tests::create_table_with_rows; use spacetimedb_lib::error::{ResultTest, TestError}; + use spacetimedb_lib::relation::ColExpr; use spacetimedb_primitives::{col_list, ColId}; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::relation::Header; @@ -287,8 +288,7 @@ pub(crate) mod tests { assert_eq!(result.len(), 1, "Not return results"); let result = result.first().unwrap().clone(); // The expected result. - let col = table.head.fields[0].field; - let inv = table.head.project(&[col]).unwrap(); + let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); let row = product![1u64]; let input = MemTable::new(inv.into(), table.table_access, vec![row]); @@ -311,8 +311,7 @@ pub(crate) mod tests { let result = result.first().unwrap().clone(); // The expected result. - let col = table.head.fields[0].field; - let inv = table.head.project(&[col]).unwrap(); + let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); let row = product![1u64]; let input = MemTable::new(inv.into(), table.table_access, vec![row]); @@ -338,8 +337,7 @@ pub(crate) mod tests { let mut result = result.first().unwrap().clone(); result.data.sort(); //The expected result - let col = table.head.fields[0].field; - let inv = table.head.project(&[col]).unwrap(); + let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); let input = MemTable::new(inv.into(), table.table_access, vec![product![1u64], product![2u64]]); @@ -364,8 +362,7 @@ pub(crate) mod tests { let mut result = result.first().unwrap().clone(); result.data.sort(); // The expected result. 
- let col = table.head.fields[0].field; - let inv = table.head.project(&[col]).unwrap(); + let inv = table.head.project(&[ColExpr::Col(0.into())]).unwrap(); let input = MemTable::new(inv.into(), table.table_access, vec![product![1u64], product![2u64]]); diff --git a/crates/core/src/subscription/execution_unit.rs b/crates/core/src/subscription/execution_unit.rs index 543b5d6e91..ddef32b3d7 100644 --- a/crates/core/src/subscription/execution_unit.rs +++ b/crates/core/src/subscription/execution_unit.rs @@ -115,18 +115,7 @@ impl ExecutionUnit { source.is_db_table(), "The plan passed to `compile_select_eval_incr` must read from `DbTable`s, but found in-mem table" ); - // NOTE: The `eval_incr_plan` will reference a `SourceExpr::InMemory` - // with `row_count: RowCount::exact(0)`. - // This is inaccurate; while we cannot predict the exact number of rows, - // we know that it will never be 0, - // as we wouldn't have a [`DatabaseTableUpdate`] with no changes. - // - // Our current query planner doesn't use the `row_count` in any meaningful way, - // so this is fine. - // Some day down the line, when we have a real query planner, - // we may need to provide a row count estimation that is, if not accurate, - // at least less specifically inaccurate. - let source = SourceExpr::from_mem_table(source.head().clone(), source.table_access(), 0, SourceId(0)); + let source = SourceExpr::from_mem_table(source.head().clone(), source.table_access(), SourceId(0)); let query = expr.query.clone(); QueryExpr { source, query } } @@ -215,41 +204,29 @@ impl ExecutionUnit { /// Evaluate this execution unit against the database using the json format. 
#[tracing::instrument(skip_all)] - pub fn eval_json( - &self, - ctx: &ExecutionContext, - db: &RelationalDB, - tx: &Tx, - sql: &str, - ) -> Result, DBError> { + pub fn eval_json(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx, sql: &str) -> Option { let table_row_operations = Self::eval_query_expr(ctx, db, tx, &self.eval_plan, sql, |row| { rel_value_to_table_row_op_json(row, OpType::Insert) - })?; - Ok((!table_row_operations.is_empty()).then(|| TableUpdateJson { + }); + (!table_row_operations.is_empty()).then(|| TableUpdateJson { table_id: self.return_table().into(), table_name: self.return_name(), table_row_operations, - })) + }) } /// Evaluate this execution unit against the database using the binary format. #[tracing::instrument(skip_all)] - pub fn eval_binary( - &self, - ctx: &ExecutionContext, - db: &RelationalDB, - tx: &Tx, - sql: &str, - ) -> Result, DBError> { + pub fn eval_binary(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx, sql: &str) -> Option { let mut scratch = Vec::new(); let table_row_operations = Self::eval_query_expr(ctx, db, tx, &self.eval_plan, sql, |row| { rel_value_to_table_row_op_binary(&mut scratch, &row, OpType::Insert) - })?; - Ok((!table_row_operations.is_empty()).then(|| TableUpdate { + }); + (!table_row_operations.is_empty()).then(|| TableUpdate { table_id: self.return_table().into(), table_name: self.return_name().into(), table_row_operations, - })) + }) } fn eval_query_expr( @@ -259,12 +236,9 @@ impl ExecutionUnit { eval_plan: &QueryExpr, sql: &str, convert: impl FnMut(RelValue<'_>) -> T, - ) -> Result, DBError> { - let tx: TxMode = tx.into(); + ) -> Vec { let _slow_query = SlowQueryLogger::subscription(ctx, sql).log_guard(); - let query = build_query(ctx, db, &tx, eval_plan, &mut NoInMemUsed)?; - let ops = query.collect_vec(convert)?; - Ok(ops) + build_query(ctx, db, &tx.into(), eval_plan, &mut NoInMemUsed).collect_vec(convert) } /// Evaluate this execution unit against the given delta tables. 
@@ -275,18 +249,18 @@ impl ExecutionUnit { tx: &'a TxMode<'a>, sql: &'a str, tables: impl 'a + Clone + Iterator, - ) -> Result>, DBError> { + ) -> Option> { let _slow_query = SlowQueryLogger::incremental_updates(ctx, sql).log_guard(); let updates = match &self.eval_incr_plan { - EvalIncrPlan::Select(plan) => Self::eval_incr_query_expr(ctx, db, tx, tables, plan, self.return_table())?, - EvalIncrPlan::Semijoin(plan) => plan.eval(ctx, db, tx, tables)?, + EvalIncrPlan::Select(plan) => Self::eval_incr_query_expr(ctx, db, tx, tables, plan, self.return_table()), + EvalIncrPlan::Semijoin(plan) => plan.eval(ctx, db, tx, tables), }; - Ok(updates.has_updates().then(|| DatabaseTableUpdateRelValue { + updates.has_updates().then(|| DatabaseTableUpdateRelValue { table_id: self.return_table(), table_name: self.return_name(), updates, - })) + }) } fn eval_query_expr_against_memtable<'a>( @@ -295,12 +269,12 @@ impl ExecutionUnit { tx: &'a TxMode, mem_table: &'a [ProductValue], eval_incr_plan: &'a QueryExpr, - ) -> Result>, DBError> { + ) -> Box> { // Provide the updates from `table`. let sources = &mut Some(mem_table.iter().map(RelValue::ProjRef)); // Evaluate the saved plan against the new updates, // returning an iterator over the selected rows. - build_query(ctx, db, tx, eval_incr_plan, sources).map_err(Into::into) + build_query(ctx, db, tx, eval_incr_plan, sources) } fn eval_incr_query_expr<'a>( @@ -310,7 +284,7 @@ impl ExecutionUnit { tables: impl Iterator, eval_incr_plan: &'a QueryExpr, return_table: TableId, - ) -> Result, DBError> { + ) -> UpdatesRelValue<'a> { assert!( eval_incr_plan.source.is_mem_table(), "Expected in-mem table in `eval_incr_plan`, but found `DbTable`" @@ -324,23 +298,22 @@ impl ExecutionUnit { // without forgetting which are inserts and which are deletes. // Previously, we used to add such a column `"__op_type: AlgebraicType::U8"`. 
if !table.inserts.is_empty() { - let query = Self::eval_query_expr_against_memtable(ctx, db, tx, &table.inserts, eval_incr_plan)?; - Self::collect_rows(&mut inserts, query)?; + let query = Self::eval_query_expr_against_memtable(ctx, db, tx, &table.inserts, eval_incr_plan); + Self::collect_rows(&mut inserts, query); } if !table.deletes.is_empty() { - let query = Self::eval_query_expr_against_memtable(ctx, db, tx, &table.deletes, eval_incr_plan)?; - Self::collect_rows(&mut deletes, query)?; + let query = Self::eval_query_expr_against_memtable(ctx, db, tx, &table.deletes, eval_incr_plan); + Self::collect_rows(&mut deletes, query); } } - Ok(UpdatesRelValue { deletes, inserts }) + UpdatesRelValue { deletes, inserts } } /// Collect the results of `query` into a vec `sink`. - fn collect_rows<'a>(sink: &mut Vec>, mut query: Box>) -> Result<(), DBError> { - while let Some(row) = query.next()? { + fn collect_rows<'a>(sink: &mut Vec>, mut query: Box>) { + while let Some(row) = query.next() { sink.push(row); } - Ok(()) } /// The estimated number of rows returned by this execution unit. 
diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 74580a149f..5b0d2a8a55 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -109,7 +109,7 @@ impl ModuleSubscriptions { &config, )?; - let database_update = execution_set.eval(&ctx, sender.protocol, &self.relational_db, &tx)?; + let database_update = execution_set.eval(&ctx, sender.protocol, &self.relational_db, &tx); WORKER_METRICS .initial_subscription_evals diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index eac0f1bee6..73e319a14c 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -146,15 +146,8 @@ impl SubscriptionManager { .par_iter() .filter_map(|(&hash, tables)| { let unit = self.queries.get(hash)?; - match unit.eval_incr(&ctx, db, tx, &unit.sql, tables.iter().copied()) { - Ok(None) => None, - Ok(Some(table)) => Some((hash, table)), - Err(err) => { - // TODO: log an id for the subscription somehow as well - tracing::error!(err = &err as &dyn std::error::Error, "subscription eval_incr failed"); - None - } - } + unit.eval_incr(&ctx, db, tx, &unit.sql, tables.iter().copied()) + .map(|table| (hash, table)) }) // If N clients are subscribed to a query, // we copy the DatabaseTableUpdate N times, diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 4c397bfa30..780ec83b35 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -202,10 +202,8 @@ mod tests { let (schema, table, data, q) = make_data(db, tx, table_name, &head, &row)?; - // For filtering out the hidden field `OP_TYPE_FIELD_NAME` let fields = &[0, 1].map(|c| FieldName::new(schema.table_id, c.into()).into()); - - let q = 
q.with_project(fields, None).unwrap(); + let q = q.with_project(fields.into(), None).unwrap(); Ok((schema, table, data, q)) } @@ -220,9 +218,8 @@ mod tests { let (schema, table, data, q) = make_data(db, tx, table_name, &head, &row)?; - let fields = &[0, 1].map(|c| FieldName::new(schema.table_id, c.into()).into()); - - let q = q.with_project(fields, None).unwrap(); + let fields = [0, 1].map(|c| FieldName::new(schema.table_id, c.into()).into()); + let q = q.with_project(fields.into(), None).unwrap(); Ok((schema, table, data, q)) } @@ -276,7 +273,7 @@ mod tests { let ctx = &ExecutionContext::incremental_update(db.address(), SlowQueryConfig::default()); let tx = &tx.into(); let update = update.tables.iter().collect::>(); - let result = s.eval_incr(ctx, db, tx, &update)?; + let result = s.eval_incr(ctx, db, tx, &update); assert_eq!( result.tables.len(), total_tables, @@ -303,7 +300,7 @@ mod tests { total_tables: usize, rows: &[ProductValue], ) -> ResultTest<()> { - let result = s.eval(ctx, Protocol::Binary, db, tx)?.tables.unwrap_left(); + let result = s.eval(ctx, Protocol::Binary, db, tx).tables.unwrap_left(); assert_eq!( result.len(), total_tables, @@ -367,7 +364,7 @@ mod tests { let ctx = &ExecutionContext::incremental_update(db.address(), SlowQueryConfig::default()); let tx = (&tx).into(); let update = update.tables.iter().collect::>(); - let result = query.eval_incr(ctx, &db, &tx, &update)?; + let result = query.eval_incr(ctx, &db, &tx, &update); assert_eq!(result.tables.len(), 1); @@ -397,7 +394,9 @@ mod tests { let q_1 = q.clone(); check_query(&db, &table, &tx, &q_1, &data)?; - let q_2 = q.with_select_cmp(OpCmp::Eq, FieldName::new(schema.table_id, 0.into()), scalar(1u64)); + let q_2 = q + .with_select_cmp(OpCmp::Eq, FieldName::new(schema.table_id, 0.into()), scalar(1u64)) + .unwrap(); check_query(&db, &table, &tx, &q_2, &data)?; Ok(()) @@ -420,11 +419,9 @@ mod tests { check_query(&db, &table, &tx, &q, &data)?; // SELECT * FROM inventory WHERE inventory_id = 1 - 
let q_id = QueryExpr::new(&*schema).with_select_cmp( - OpCmp::Eq, - FieldName::new(schema.table_id, 0.into()), - scalar(1u64), - ); + let q_id = QueryExpr::new(&*schema) + .with_select_cmp(OpCmp::Eq, FieldName::new(schema.table_id, 0.into()), scalar(1u64)) + .unwrap(); let s = singleton_execution_set(q_id, "SELECT * FROM inventory WHERE inventory_id = 1".into())?; @@ -723,7 +720,7 @@ mod tests { db.with_read_only(ctx, |tx| { let tx = (&*tx).into(); let update = update.tables.iter().collect::>(); - let result = query.eval_incr(ctx, db, &tx, &update)?; + let result = query.eval_incr(ctx, db, &tx, &update); let tables = result .tables .into_iter() diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 6468f7bfb9..15641e1ee5 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -41,7 +41,6 @@ use spacetimedb_primitives::TableId; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::db::error::AuthError; use spacetimedb_sats::relation::DbTable; -use spacetimedb_vm::errors::ErrorVm; use spacetimedb_vm::expr::{self, AuthAccess, IndexJoin, Query, QueryExpr, SourceExpr, SourceProvider, SourceSet}; use spacetimedb_vm::rel_ops::RelOps; use spacetimedb_vm::relation::{MemTable, RelValue}; @@ -125,8 +124,6 @@ impl AsRef for SupportedQuery { } } -type ResRV<'a> = Result, ErrorVm>; - /// Evaluates `query` and returns all the updates. 
fn eval_updates<'a>( ctx: &'a ExecutionContext, @@ -134,9 +131,9 @@ fn eval_updates<'a>( tx: &'a TxMode<'a>, query: &'a QueryExpr, mut sources: impl SourceProvider<'a>, -) -> Result>, DBError> { - let mut query = build_query(ctx, db, tx, query, &mut sources)?; - Ok(iter::from_fn(move || query.next().transpose())) +) -> impl 'a + Iterator> { + let mut query = build_query(ctx, db, tx, query, &mut sources); + iter::from_fn(move || query.next()) } /// A [`query::Supported::Semijoin`] compiled for incremental evaluations. @@ -268,7 +265,7 @@ impl IncrementalJoin { db: &'a RelationalDB, tx: &'a TxMode<'a>, lhs: impl 'a + Iterator, - ) -> Result>, DBError> { + ) -> impl Iterator> { eval_updates(ctx, db, tx, self.plan_for_delta_lhs(), Some(lhs.map(RelValue::ProjRef))) } @@ -279,7 +276,7 @@ impl IncrementalJoin { db: &'a RelationalDB, tx: &'a TxMode<'a>, rhs: impl 'a + Iterator, - ) -> Result>, DBError> { + ) -> impl Iterator> { eval_updates(ctx, db, tx, self.plan_for_delta_rhs(), Some(rhs.map(RelValue::ProjRef))) } @@ -291,7 +288,7 @@ impl IncrementalJoin { tx: &'a TxMode<'a>, lhs: impl 'a + Iterator, rhs: impl 'a + Iterator, - ) -> Result>, DBError> { + ) -> impl Iterator> { let is = Either::Left(lhs.map(RelValue::ProjRef)); let ps = Either::Right(rhs.map(RelValue::ProjRef)); let sources: SourceSet<_, 2> = if self.return_index_rows { [is, ps] } else { [ps, is] }.into(); @@ -356,7 +353,7 @@ impl IncrementalJoin { db: &'a RelationalDB, tx: &'a TxMode<'a>, updates: impl 'a + Clone + Iterator, - ) -> Result, DBError> { + ) -> UpdatesRelValue<'a> { // Find any updates to the tables mentioned by `self` and group them into [`JoinSide`]s. // // The supplied updates are assumed to be the full set of updates from a single transaction. 
@@ -393,84 +390,78 @@ impl IncrementalJoin { let has_rhs_deletes = rhs_deletes.peek().is_some(); let has_rhs_inserts = rhs_inserts.peek().is_some(); if !has_lhs_deletes && !has_lhs_inserts && !has_rhs_deletes && !has_rhs_inserts { - return Ok(<_>::default()); + return <_>::default(); } // Compute the incremental join // ===================================================================== - fn collect_set>>( + fn collect_set>( produce_if: bool, - producer: impl FnOnce() -> Result, - ) -> Result, DBError> { - Ok(if produce_if { - let iter = producer()?; - let mut set = HashSet::new(); - for x in iter { - set.insert(x?); - } - set + producer: impl FnOnce() -> I, + ) -> HashSet { + if produce_if { + producer().collect() } else { HashSet::new() - }) + } } - fn make_iter>>( + fn make_iter>( produce_if: bool, - producer: impl FnOnce() -> Result, - ) -> Result>, DBError> { - Ok(if produce_if { - let iter = producer()?; - Either::Left(iter) + producer: impl FnOnce() -> I, + ) -> impl Iterator { + if produce_if { + Either::Left(producer()) } else { Either::Right(iter::empty()) - }) + } } // (1) A+ x B(t) let j1_lhs_ins = lhs_inserts.clone(); - let join_1 = make_iter(has_lhs_inserts, || self.eval_lhs(ctx, db, tx, j1_lhs_ins))?; + let join_1 = make_iter(has_lhs_inserts, || self.eval_lhs(ctx, db, tx, j1_lhs_ins)); // (2) A- x B(t) let j2_lhs_del = lhs_deletes.clone(); - let mut join_2 = collect_set(has_lhs_deletes, || self.eval_lhs(ctx, db, tx, j2_lhs_del))?; + let mut join_2 = collect_set(has_lhs_deletes, || self.eval_lhs(ctx, db, tx, j2_lhs_del)); // (3) A- x B+ let j3_lhs_del = lhs_deletes.clone(); let j3_rhs_ins = rhs_inserts.clone(); let join_3 = make_iter(has_lhs_deletes && has_rhs_inserts, || { self.eval_all(ctx, db, tx, j3_lhs_del, j3_rhs_ins) - })?; + }); // (4) A- x B- let j4_rhs_del = rhs_deletes.clone(); let join_4 = make_iter(has_lhs_deletes && has_rhs_deletes, || { self.eval_all(ctx, db, tx, lhs_deletes, j4_rhs_del) - })?; + }); // (5) A(t) x B+ let j5_rhs_ins = 
rhs_inserts.clone(); - let mut join_5 = collect_set(has_rhs_inserts, || self.eval_rhs(ctx, db, tx, j5_rhs_ins))?; + let mut join_5 = collect_set(has_rhs_inserts, || self.eval_rhs(ctx, db, tx, j5_rhs_ins)); // (6) A(t) x B- let j6_rhs_del = rhs_deletes.clone(); - let mut join_6 = collect_set(has_rhs_deletes, || self.eval_rhs(ctx, db, tx, j6_rhs_del))?; + let mut join_6 = collect_set(has_rhs_deletes, || self.eval_rhs(ctx, db, tx, j6_rhs_del)); // (7) A+ x B+ let j7_lhs_ins = lhs_inserts.clone(); let join_7 = make_iter(has_lhs_inserts && has_rhs_inserts, || { self.eval_all(ctx, db, tx, j7_lhs_ins, rhs_inserts) - })?; + }); // (8) A+ x B- let join_8 = make_iter(has_lhs_inserts && has_rhs_deletes, || { self.eval_all(ctx, db, tx, lhs_inserts, rhs_deletes) - })?; + }); // A- x B(s) = A- x B(t) \ A- x B+ for row in join_3 { - join_2.remove(&row?); + join_2.remove(&row); } // A(s) x B+ = A(t) x B+ \ A+ x B+ for row in join_7 { - join_5.remove(&row?); + join_5.remove(&row); } // A(s) x B- = A(t) x B- \ A+ x B- for row in join_8 { - join_6.remove(&row?); + join_6.remove(&row); } join_5.retain(|row| !join_6.remove(row)); @@ -479,18 +470,18 @@ impl IncrementalJoin { let mut deletes = Vec::new(); deletes.extend(join_2); for row in join_4 { - deletes.push(row?); + deletes.push(row); } deletes.extend(join_6); // Collect inserts: let mut inserts = Vec::new(); for row in join_1 { - inserts.push(row?); + inserts.push(row); } inserts.extend(join_5); - Ok(UpdatesRelValue { deletes, inserts }) + UpdatesRelValue { deletes, inserts } } } @@ -530,32 +521,32 @@ impl ExecutionSet { protocol: Protocol, db: &RelationalDB, tx: &Tx, - ) -> Result { + ) -> ProtocolDatabaseUpdate { let tables = match protocol { - Protocol::Binary => Either::Left(self.eval_binary(ctx, db, tx)?), - Protocol::Text => Either::Right(self.eval_json(ctx, db, tx)?), + Protocol::Binary => Either::Left(self.eval_binary(ctx, db, tx)), + Protocol::Text => Either::Right(self.eval_json(ctx, db, tx)), }; - 
Ok(ProtocolDatabaseUpdate { tables }) + ProtocolDatabaseUpdate { tables } } #[tracing::instrument(skip_all)] - fn eval_json(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx) -> Result, DBError> { + fn eval_json(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx) -> Vec { // evaluate each of the execution units in this ExecutionSet in parallel self.exec_units // if you need eval to run single-threaded for debugging, change this to .iter() .par_iter() - .filter_map(|unit| unit.eval_json(ctx, db, tx, &unit.sql).transpose()) - .collect::, _>>() + .filter_map(|unit| unit.eval_json(ctx, db, tx, &unit.sql)) + .collect() } #[tracing::instrument(skip_all)] - fn eval_binary(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx) -> Result, DBError> { + fn eval_binary(&self, ctx: &ExecutionContext, db: &RelationalDB, tx: &Tx) -> Vec { // evaluate each of the execution units in this ExecutionSet in parallel self.exec_units // if you need eval to run single-threaded for debugging, change this to .iter() .par_iter() - .filter_map(|unit| unit.eval_binary(ctx, db, tx, &unit.sql).transpose()) - .collect::, _>>() + .filter_map(|unit| unit.eval_binary(ctx, db, tx, &unit.sql)) + .collect() } #[tracing::instrument(skip_all)] @@ -565,14 +556,14 @@ impl ExecutionSet { db: &'a RelationalDB, tx: &'a TxMode<'a>, database_update: &'a [&'a DatabaseTableUpdate], - ) -> Result, DBError> { + ) -> DatabaseUpdateRelValue<'a> { let mut tables = Vec::new(); for unit in &self.exec_units { - if let Some(table) = unit.eval_incr(ctx, db, tx, &unit.sql, database_update.iter().copied())? { + if let Some(table) = unit.eval_incr(ctx, db, tx, &unit.sql, database_update.iter().copied()) { tables.push(table); } } - Ok(DatabaseUpdateRelValue { tables }) + DatabaseUpdateRelValue { tables } } /// The estimated number of rows returned by this execution set. 
diff --git a/crates/core/src/vm.rs b/crates/core/src/vm.rs index a52aa38c37..d63e71ba25 100644 --- a/crates/core/src/vm.rs +++ b/crates/core/src/vm.rs @@ -1,21 +1,21 @@ //! The [DbProgram] that execute arbitrary queries & code against the database. use crate::config::DatabaseConfig; -use crate::db::cursor::{IndexCursor, TableCursor}; -use crate::db::datastore::locking_tx_datastore::tx::TxId; use crate::db::datastore::locking_tx_datastore::IterByColRange; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::DBError; use crate::estimation; use crate::execution_context::ExecutionContext; -use core::ops::RangeBounds; +use core::ops::{Bound, RangeBounds}; use itertools::Itertools; -use spacetimedb_data_structures::map::HashMap; +use spacetimedb_data_structures::map::IntMap; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::*; use spacetimedb_sats::db::def::TableDef; -use spacetimedb_sats::relation::{DbTable, FieldExpr, FieldName, Header, RowCount}; +use spacetimedb_sats::relation::{ColExpr, DbTable}; use spacetimedb_sats::{AlgebraicValue, ProductValue}; +use spacetimedb_table::static_assert_size; +use spacetimedb_table::table::RowRef; use spacetimedb_vm::errors::ErrorVm; use spacetimedb_vm::eval::{build_project, build_select, join_inner, IterRows}; use spacetimedb_vm::expr::*; @@ -23,8 +23,6 @@ use spacetimedb_vm::iterators::RelIter; use spacetimedb_vm::program::{ProgramVm, Sources}; use spacetimedb_vm::rel_ops::{EmptyRelOps, RelOps}; use spacetimedb_vm::relation::{MemTable, RelValue}; -use std::ops::Bound; -use std::sync::Arc; pub enum TxMode<'a> { MutTx(&'a mut MutTx), @@ -79,7 +77,7 @@ pub fn build_query<'a>( tx: &'a TxMode<'a>, query: &'a QueryExpr, sources: &mut impl SourceProvider<'a>, -) -> Result>, ErrorVm> { +) -> Box> { let db_table = query.source.is_db_table(); // We're incrementally building a query iterator by applying each operation in the `query.query`. 
@@ -101,7 +99,6 @@ pub fn build_query<'a>( let result_or_base = |sources: &mut _, result: &mut Option<_>| { result .take() - .map(Ok) .unwrap_or_else(|| get_table(ctx, db, tx, &query.source, sources)) }; @@ -114,14 +111,14 @@ pub fn build_query<'a>( // return an empty iterator. // This avoids a panic in `BTreeMap`'s `NodeRef::search_tree_for_bifurcation`, // which is very unhappy about unsatisfiable bounds. - Box::new(EmptyRelOps::new(table.head.clone())) as Box> + Box::new(EmptyRelOps) as Box> } else { let bounds = (bounds.start_bound(), bounds.end_bound()); - iter_by_col_range(ctx, db, tx, table, columns.clone(), bounds)? + iter_by_col_range(ctx, db, tx, table, columns.clone(), bounds) } } Query::IndexScan(index_scan) => { - let result = result_or_base(sources, &mut result)?; + let result = result_or_base(sources, &mut result); let cols = &index_scan.columns; let bounds = &index_scan.bounds; @@ -136,11 +133,11 @@ pub fn build_query<'a>( // The current behavior is a hack // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha, // and a more invasive change was infeasible. - Box::new(EmptyRelOps::new(index_scan.table.head.clone())) as Box> + Box::new(EmptyRelOps) as Box> } else if cols.is_singleton() { // For singleton constraints, we compare the column directly against `bounds`. let head = cols.head().idx(); - let iter = result.select(move |row| Ok(bounds.contains(&*row.read_column(head).unwrap()))); + let iter = result.select(move |row| bounds.contains(&*row.read_column(head).unwrap())); Box::new(iter) as Box> } else { // For multi-col constraints, these are stored as bounds of product values, @@ -155,18 +152,16 @@ pub fn build_query<'a>( // and compare against the column in the row. // All columns must match to include the row, // which is essentially the same as a big `AND` of `ColumnOp`s. 
- Ok(cols.iter().enumerate().all(|(idx, col)| { + cols.iter().enumerate().all(|(idx, col)| { let start_bound = start_bound.map(|pv| &pv[idx]); let end_bound = end_bound.map(|pv| &pv[idx]); let read_col = row.read_column(col.idx()).unwrap(); (start_bound, end_bound).contains(&*read_col) - })) + }) })) } } - Query::IndexJoin(_) if result.is_some() => { - return Err(anyhow::anyhow!("Invalid query: `IndexJoin` must be the first operator").into()) - } + Query::IndexJoin(_) if result.is_some() => panic!("Invalid query: `IndexJoin` must be the first operator"), Query::IndexJoin(IndexJoin { probe_side, probe_col, @@ -174,26 +169,42 @@ pub fn build_query<'a>( index_select, index_col, return_index_rows, - }) => Box::new(IndexSemiJoin { - ctx, - db, - tx, - probe_side: build_query(ctx, db, tx, probe_side, sources)?, - probe_col: *probe_col, - index_header: index_side.head(), - index_select, + }) => { + let probe_side = build_query(ctx, db, tx, probe_side, sources); // The compiler guarantees that the index side is a db table, // and therefore this unwrap is always safe. 
- index_table: index_side.table_id().unwrap(), - index_col: *index_col, - index_iter: None, - return_index_rows: *return_index_rows, - }), - Query::Select(cmp) => build_select(result_or_base(sources, &mut result)?, cmp), - Query::Project(proj) => build_project(result_or_base(sources, &mut result)?, proj), + let index_table = index_side.table_id().unwrap(); + + if *return_index_rows { + Box::new(IndexSemiJoinLeft { + ctx, + db, + tx, + probe_side, + probe_col: *probe_col, + index_select, + index_table, + index_col: *index_col, + index_iter: None, + }) as Box> + } else { + Box::new(IndexSemiJoinRight { + ctx, + db, + tx, + probe_side, + probe_col: *probe_col, + index_select, + index_table, + index_col: *index_col, + }) + } + } + Query::Select(cmp) => build_select(result_or_base(sources, &mut result), cmp), + Query::Project(proj) => build_project(result_or_base(sources, &mut result), proj), Query::JoinInner(join) => join_inner( - result_or_base(sources, &mut result)?, - build_query(ctx, db, tx, &join.rhs, sources)?, + result_or_base(sources, &mut result), + build_query(ctx, db, tx, &join.rhs, sources), join, ), }) @@ -216,37 +227,24 @@ fn get_table<'a>( ctx: &'a ExecutionContext, stdb: &'a RelationalDB, tx: &'a TxMode, - query: &SourceExpr, + query: &'a SourceExpr, sources: &mut impl SourceProvider<'a>, -) -> Result>, ErrorVm> { - Ok(match query { - SourceExpr::InMemory { - source_id, - header, - row_count, - .. - } => in_mem_to_rel_ops(sources, *source_id, header.clone(), *row_count), - SourceExpr::DbTable(x) => { - let iter = match tx { - TxMode::MutTx(tx) => stdb.iter_mut(ctx, tx, x.table_id)?, - TxMode::Tx(tx) => stdb.iter(ctx, tx, x.table_id)?, - }; - Box::new(TableCursor::new(x.clone(), iter)?) as Box> - } - }) -} - -// Extracts an in-memory table with `source_id` from `sources` and builds a query for the table. -fn in_mem_to_rel_ops<'a>( - sources: &mut impl SourceProvider<'a>, - source_id: SourceId, - head: Arc
, - rc: RowCount, ) -> Box> { - let source = sources.take_source(source_id).unwrap_or_else(|| { - panic!("Query plan specifies in-mem table for {source_id:?}, but found a `DbTable` or nothing") - }); - Box::new(RelIter::new(head, rc, source)) as Box> + match query { + // Extracts an in-memory table with `source_id` from `sources` and builds a query for the table. + SourceExpr::InMemory { source_id, .. } => build_iter( + sources + .take_source(*source_id) + .unwrap_or_else(|| { + panic!("Query plan specifies in-mem table for {source_id:?}, but found a `DbTable` or nothing") + }) + .into_iter(), + ), + SourceExpr::DbTable(db_table) => build_iter_from_db(match tx { + TxMode::MutTx(tx) => stdb.iter_mut(ctx, tx, db_table.table_id), + TxMode::Tx(tx) => stdb.iter(ctx, tx, db_table.table_id), + }), + } } fn iter_by_col_range<'a>( @@ -256,31 +254,36 @@ fn iter_by_col_range<'a>( table: &'a DbTable, columns: ColList, range: impl RangeBounds + 'a, -) -> Result + 'a>, ErrorVm> { - let iter = match tx { - TxMode::MutTx(tx) => db.iter_by_col_range_mut(ctx, tx, table.table_id, columns, range)?, - TxMode::Tx(tx) => db.iter_by_col_range(ctx, tx, table.table_id, columns, range)?, - }; - Ok(Box::new(IndexCursor::new(table, iter)?) as Box>) +) -> Box> { + build_iter_from_db(match tx { + TxMode::MutTx(tx) => db.iter_by_col_range_mut(ctx, tx, table.table_id, columns, range), + TxMode::Tx(tx) => db.iter_by_col_range(ctx, tx, table.table_id, columns, range), + }) } +fn build_iter_from_db<'a>(iter: Result>, DBError>) -> Box> { + build_iter(iter.expect(TABLE_ID_EXPECTED_VALID).map(RelValue::Row)) +} + +fn build_iter<'a>(iter: impl 'a + Iterator>) -> Box> { + Box::new(RelIter::new(iter)) as Box> +} + +const TABLE_ID_EXPECTED_VALID: &str = "all `table_id`s in compiled query should be valid"; + /// An index join operator that returns matching rows from the index side. 
-pub struct IndexSemiJoin<'a, 'c, Rhs: RelOps<'a>> { +pub struct IndexSemiJoinLeft<'a, 'c, Rhs: RelOps<'a>> { /// An iterator for the probe side. /// The values returned will be used to probe the index. pub probe_side: Rhs, /// The column whose value will be used to probe the index. pub probe_col: ColId, - /// The header for the index side of the join. - pub index_header: &'c Arc
, /// An optional predicate to evaluate over the matching rows of the index. pub index_select: &'c Option, /// The table id on which the index is defined. pub index_table: TableId, /// The column id for which the index is defined. pub index_col: ColId, - /// Is this a left or right semijoin? - pub return_index_rows: bool, /// An iterator for the index side. /// A new iterator will be instantiated for each row on the probe side. pub index_iter: Option>, @@ -292,64 +295,97 @@ pub struct IndexSemiJoin<'a, 'c, Rhs: RelOps<'a>> { ctx: &'a ExecutionContext, } -impl<'a, Rhs: RelOps<'a>> IndexSemiJoin<'a, '_, Rhs> { - fn filter(&self, index_row: &RelValue<'_>) -> Result { - Ok(if let Some(op) = &self.index_select { - op.compare(index_row, self.index_header)? - } else { - true - }) - } +static_assert_size!(IndexSemiJoinLeft>>, 312); - fn map(&self, index_row: RelValue<'a>, probe_row: Option>) -> RelValue<'a> { - if let Some(value) = probe_row { - if !self.return_index_rows { - return value; - } - } - index_row +impl<'a, Rhs: RelOps<'a>> IndexSemiJoinLeft<'a, '_, Rhs> { + fn filter(&self, index_row: &RelValue<'_>) -> bool { + self.index_select.as_ref().map_or(true, |op| op.compare(index_row)) } } -impl<'a, Rhs: RelOps<'a>> RelOps<'a> for IndexSemiJoin<'a, '_, Rhs> { - fn head(&self) -> &Arc
{ - if self.return_index_rows { - self.index_header - } else { - self.probe_side.head() +impl<'a, Rhs: RelOps<'a>> RelOps<'a> for IndexSemiJoinLeft<'a, '_, Rhs> { + fn next(&mut self) -> Option> { + // Return a value from the current index iterator, if not exhausted. + while let Some(index_row) = self.index_iter.as_mut().and_then(|iter| iter.next()).map(RelValue::Row) { + if self.filter(&index_row) { + return Some(index_row); + } } - } - fn next(&mut self) -> Result>, ErrorVm> { - // Return a value from the current index iterator, if not exhausted. - if self.return_index_rows { - while let Some(value) = self.index_iter.as_mut().and_then(|iter| iter.next()) { - let value = RelValue::Row(value); - if self.filter(&value)? { - return Ok(Some(self.map(value, None))); + // Otherwise probe the index with a row from the probe side. + let table_id = self.index_table; + let index_col = self.index_col; + let probe_col = self.probe_col.idx(); + while let Some(mut row) = self.probe_side.next() { + if let Some(value) = row.read_or_take_column(probe_col) { + let index_iter = match self.tx { + TxMode::MutTx(tx) => self.db.iter_by_col_range_mut(self.ctx, tx, table_id, index_col, value), + TxMode::Tx(tx) => self.db.iter_by_col_range(self.ctx, tx, table_id, index_col, value), + }; + let mut index_iter = index_iter.expect(TABLE_ID_EXPECTED_VALID); + while let Some(index_row) = index_iter.next().map(RelValue::Row) { + if self.filter(&index_row) { + self.index_iter = Some(index_iter); + return Some(index_row); + } } } } + None + } +} +/// An index join operator that returns matching rows from the index side. +pub struct IndexSemiJoinRight<'a, 'c, Rhs: RelOps<'a>> { + /// An iterator for the probe side. + /// The values returned will be used to probe the index. + pub probe_side: Rhs, + /// The column whose value will be used to probe the index. + pub probe_col: ColId, + /// An optional predicate to evaluate over the matching rows of the index. 
+ pub index_select: &'c Option, + /// The table id on which the index is defined. + pub index_table: TableId, + /// The column id for which the index is defined. + pub index_col: ColId, + /// A reference to the database. + pub db: &'a RelationalDB, + /// A reference to the current transaction. + pub tx: &'a TxMode<'a>, + /// The execution context for the current transaction. + ctx: &'a ExecutionContext, +} + +static_assert_size!(IndexSemiJoinRight>>, 64); + +impl<'a, Rhs: RelOps<'a>> IndexSemiJoinRight<'a, '_, Rhs> { + fn filter(&self, index_row: &RelValue<'_>) -> bool { + self.index_select.as_ref().map_or(true, |op| op.compare(index_row)) + } +} + +impl<'a, Rhs: RelOps<'a>> RelOps<'a> for IndexSemiJoinRight<'a, '_, Rhs> { + fn next(&mut self) -> Option> { // Otherwise probe the index with a row from the probe side. let table_id = self.index_table; - let col_id = self.index_col; - while let Some(mut row) = self.probe_side.next()? { - if let Some(value) = row.read_or_take_column(self.probe_col.idx()) { - let mut index_iter = match self.tx { - TxMode::MutTx(tx) => self.db.iter_by_col_range_mut(self.ctx, tx, table_id, col_id, value)?, - TxMode::Tx(tx) => self.db.iter_by_col_range(self.ctx, tx, table_id, col_id, value)?, + let index_col = self.index_col; + let probe_col = self.probe_col.idx(); + while let Some(row) = self.probe_side.next() { + if let Some(value) = row.read_column(probe_col) { + let value = &*value; + let index_iter = match self.tx { + TxMode::MutTx(tx) => self.db.iter_by_col_range_mut(self.ctx, tx, table_id, index_col, value), + TxMode::Tx(tx) => self.db.iter_by_col_range(self.ctx, tx, table_id, index_col, value), }; - while let Some(value) = index_iter.next() { - let value = RelValue::Row(value); - if self.filter(&value)? 
{ - self.index_iter = Some(index_iter); - return Ok(Some(self.map(value, Some(row)))); + let mut index_iter = index_iter.expect(TABLE_ID_EXPECTED_VALID); + while let Some(index_row) = index_iter.next().map(RelValue::Row) { + if self.filter(&index_row) { + return Some(row); } } } } - Ok(None) + None } } @@ -366,8 +402,8 @@ pub struct DbProgram<'db, 'tx> { /// reject the request if the estimated cardinality exceeds the limit. pub fn check_row_limit( queries: &QuerySet, - tx: &TxId, - row_est: impl Fn(&QuerySet, &TxId) -> u64, + tx: &Tx, + row_est: impl Fn(&QuerySet, &Tx) -> u64, auth: &AuthCtx, config: &DatabaseConfig, ) -> Result<(), DBError> { @@ -403,11 +439,11 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { let table_access = query.source.table_access(); tracing::trace!(table = query.source.table_name()); - let result = build_query(self.ctx, self.db, self.tx, query, &mut |id| { + let head = query.head().clone(); + let rows = build_query(self.ctx, self.db, self.tx, query, &mut |id| { sources.take(id).map(|mt| mt.into_iter().map(RelValue::Projection)) - })?; - let head = result.head().clone(); - let rows = result.collect_vec(|row| row.into_product_value())?; + }) + .collect_vec(|row| row.into_product_value()); Ok(Code::Table(MemTable::new(head, table_access, rows))) } @@ -423,7 +459,7 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { fn _execute_update( &mut self, delete: &QueryExpr, - mut assigns: HashMap, + mut assigns: IntMap, sources: Sources<'_, N>, ) -> Result { let result = self._eval_query(delete, sources)?; @@ -441,7 +477,11 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { // Replace the columns in the matched rows with the assigned // values. No typechecking is performed here, nor that all // assignments are consumed. 
- let exprs: Vec> = table.head.fields.iter().map(|col| assigns.remove(&col.field)).collect(); + let exprs: Vec> = (0..table.head.fields.len()) + .map(ColId::from) + .map(|c| assigns.remove(&c)) + .collect(); + let insert_rows = deleted .data .into_iter() @@ -450,7 +490,7 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { .into_iter() .zip(&exprs) .map(|(val, expr)| { - if let Some(FieldExpr::Value(assigned)) = expr { + if let Some(ColExpr::Value(assigned)) = expr { assigned.clone() } else { val @@ -534,7 +574,7 @@ impl ProgramVm for DbProgram<'_, '_> { CrudExpr::Insert { table, rows } => self._execute_insert(&table, rows), CrudExpr::Update { delete, assignments } => self._execute_update(&delete, assignments, sources), CrudExpr::Delete { query } => self._delete_query(&query, sources), - CrudExpr::CreateTable { table } => self._create_table(table), + CrudExpr::CreateTable { table } => self._create_table(*table), CrudExpr::Drop { name, kind, .. } => self._drop(&name, kind), CrudExpr::SetVar { name, value } => self._set_config(name, value), CrudExpr::ReadVar { name } => self._read_config(name), @@ -542,26 +582,6 @@ impl ProgramVm for DbProgram<'_, '_> { } } -impl<'a> RelOps<'a> for TableCursor<'a> { - fn head(&self) -> &Arc
{ - &self.table.head - } - - fn next(&mut self) -> Result>, ErrorVm> { - Ok(self.iter.next().map(RelValue::Row)) - } -} - -impl<'a, R: RangeBounds> RelOps<'a> for IndexCursor<'a, R> { - fn head(&self) -> &Arc
{ - &self.table.head - } - - fn next(&mut self) -> Result>, ErrorVm> { - Ok(self.iter.next().map(RelValue::Row)) - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; @@ -575,11 +595,12 @@ pub(crate) mod tests { use spacetimedb_lib::error::ResultTest; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::db::def::{ColumnDef, IndexDef, IndexType, TableSchema}; - use spacetimedb_sats::relation::FieldName; + use spacetimedb_sats::relation::{FieldName, Header}; use spacetimedb_sats::{product, AlgebraicType, ProductType, ProductValue}; use spacetimedb_vm::eval::run_ast; use spacetimedb_vm::eval::test_helpers::{mem_table, mem_table_one_u64, scalar}; use spacetimedb_vm::operator::OpCmp; + use std::sync::Arc; pub(crate) fn create_table_with_rows( db: &RelationalDB, @@ -697,11 +718,13 @@ pub(crate) mod tests { let stdb = TestDB::durable()?; let schema = &*stdb.schema_for_table(&stdb.begin_tx(), ST_TABLES_ID).unwrap(); - let q = QueryExpr::new(schema).with_select_cmp( - OpCmp::Eq, - FieldName::new(ST_TABLES_ID, StTableFields::TableName.into()), - scalar(ST_TABLES_NAME), - ); + let q = QueryExpr::new(schema) + .with_select_cmp( + OpCmp::Eq, + FieldName::new(ST_TABLES_ID, StTableFields::TableName.into()), + scalar(ST_TABLES_NAME), + ) + .unwrap(); let st_table_row = StTableRow { table_id: ST_TABLES_ID, table_name: ST_TABLES_NAME.into(), @@ -725,11 +748,13 @@ pub(crate) mod tests { FieldName::new(ST_COLUMNS_ID, StColumnFields::TableId.into()), scalar(ST_COLUMNS_ID), ) + .unwrap() .with_select_cmp( OpCmp::Eq, FieldName::new(ST_COLUMNS_ID, StColumnFields::ColPos.into()), scalar(StColumnFields::TableId as u32), - ); + ) + .unwrap(); let st_column_row = StColumnRow { table_id: ST_COLUMNS_ID, col_pos: StColumnFields::TableId.col_id(), @@ -754,11 +779,13 @@ pub(crate) mod tests { let index_id = db.with_auto_commit(&ctx, |tx| db.create_index(tx, table_id, index))?; let indexes_schema = &*db.schema_for_table(&db.begin_tx(), ST_INDEXES_ID).unwrap(); - let q 
= QueryExpr::new(indexes_schema).with_select_cmp( - OpCmp::Eq, - FieldName::new(ST_INDEXES_ID, StIndexFields::IndexName.into()), - scalar("idx_1"), - ); + let q = QueryExpr::new(indexes_schema) + .with_select_cmp( + OpCmp::Eq, + FieldName::new(ST_INDEXES_ID, StIndexFields::IndexName.into()), + scalar("idx_1"), + ) + .unwrap(); let st_index_row = StIndexRow { index_id, index_name: "idx_1".into(), @@ -778,11 +805,13 @@ pub(crate) mod tests { let db = TestDB::durable()?; let schema = &*db.schema_for_table(&db.begin_tx(), ST_SEQUENCES_ID).unwrap(); - let q = QueryExpr::new(schema).with_select_cmp( - OpCmp::Eq, - FieldName::new(ST_SEQUENCES_ID, StSequenceFields::TableId.into()), - scalar(ST_SEQUENCES_ID), - ); + let q = QueryExpr::new(schema) + .with_select_cmp( + OpCmp::Eq, + FieldName::new(ST_SEQUENCES_ID, StSequenceFields::TableId.into()), + scalar(ST_SEQUENCES_ID), + ) + .unwrap(); let st_sequence_row = StSequenceRow { sequence_id: 3.into(), sequence_name: "seq_st_sequence_sequence_id_primary_key_auto".into(), diff --git a/crates/data-structures/src/map.rs b/crates/data-structures/src/map.rs index 5ff7cdfbc1..521b1f8f58 100644 --- a/crates/data-structures/src/map.rs +++ b/crates/data-structures/src/map.rs @@ -8,6 +8,10 @@ pub use nohash_hasher::IsEnabled as ValidAsIdentityHash; /// which is valid for any key type that can be converted to a `u64` without truncation. pub type IntMap = HashMap>; +/// A version of [`HashSet`] using the identity hash function, +/// which is valid for any key type that can be converted to a `u64` without truncation. +pub type IntSet = HashSet>; + pub trait HashCollectionExt { /// Returns a new collection with default capacity, using `S::default()` to build the hasher. 
fn new() -> Self; diff --git a/crates/sats/src/db/error.rs b/crates/sats/src/db/error.rs index 5b1de6fab5..b8636e735b 100644 --- a/crates/sats/src/db/error.rs +++ b/crates/sats/src/db/error.rs @@ -1,6 +1,7 @@ use crate::db::def::IndexType; use crate::product_value::InvalidFieldError; use crate::relation::{FieldName, Header}; +use crate::satn::Satn as _; use crate::{buffer, AlgebraicType, AlgebraicValue}; use derive_more::Display; use spacetimedb_primitives::{ColId, ColList, TableId}; @@ -88,10 +89,12 @@ pub enum RelationError { FieldNotFound(Header, FieldName), #[error("Field `{0}` fail to infer the type: {1}")] TypeInference(FieldName, TypeError), + #[error("Field with value `{}` was not a `bool`", val.to_satn())] + NotBoolValue { val: AlgebraicValue }, + #[error("Field `{field}` was expected to be `bool` but is `{}`", ty.to_satn())] + NotBoolType { field: FieldName, ty: AlgebraicType }, #[error("Field declaration only support `table.field` or `field`. It gets instead `{0}`")] FieldPathInvalid(String), - #[error("Field `{1}` not found at position {0}")] - FieldNotFoundAtPos(usize, FieldName), } #[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Display)] diff --git a/crates/sats/src/relation.rs b/crates/sats/src/relation.rs index 25c069df7f..7a7f26069b 100644 --- a/crates/sats/src/relation.rs +++ b/crates/sats/src/relation.rs @@ -31,17 +31,17 @@ impl FieldName { // TODO(perf): Remove `Clone` derivation. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, From)] -pub enum FieldExpr { - Name(FieldName), +pub enum ColExpr { + Col(ColId), Value(AlgebraicValue), } -impl FieldExpr { - /// Returns a borrowed version of `FieldExpr`. - pub fn borrowed(&self) -> FieldExprRef<'_> { +impl ColExpr { + /// Returns a borrowed version of `ColExpr`. 
+ pub fn borrowed(&self) -> ColExprRef<'_> { match self { - Self::Name(x) => FieldExprRef::Name(*x), - Self::Value(x) => FieldExprRef::Value(x), + Self::Col(x) => ColExprRef::Col(*x), + Self::Value(x) => ColExprRef::Value(x), } } } @@ -58,19 +58,19 @@ impl fmt::Display for FieldName { } } -impl fmt::Display for FieldExpr { +impl fmt::Display for ColExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - FieldExpr::Name(x) => write!(f, "{x}"), - FieldExpr::Value(x) => write!(f, "{}", x.to_satn()), + ColExpr::Col(x) => write!(f, "{x}"), + ColExpr::Value(x) => write!(f, "{}", x.to_satn()), } } } /// A borrowed version of `FieldExpr`. #[derive(Clone, Copy)] -pub enum FieldExprRef<'a> { - Name(FieldName), +pub enum ColExprRef<'a> { + Col(ColId), Value(&'a AlgebraicValue), } @@ -156,22 +156,21 @@ impl Header { .any(|(col, ct)| col.contains(field) && ct.contains(&constraint)) } - /// Project the [FieldExpr] & the [Constraints] that referenced them - pub fn project(&self, cols: &[impl Into + Clone]) -> Result { + /// Project the [ColExpr]s & the [Constraints] that referenced them + pub fn project(&self, cols: &[ColExpr]) -> Result { let mut p = Vec::with_capacity(cols.len()); let mut to_keep = ColListBuilder::new(); for (pos, col) in cols.iter().enumerate() { - match col.clone().into() { - FieldExpr::Name(col) => { - let pos = self.column_pos_or_err(col)?; - to_keep.push(pos); - p.push(self.fields[pos.idx()].clone()); + match col { + ColExpr::Col(col) => { + to_keep.push(*col); + p.push(self.fields[col.idx()].clone()); } - FieldExpr::Value(col) => { + ColExpr::Value(val) => { let field = FieldName::new(self.table_id, pos.into()); - let ty = col.type_of().ok_or_else(|| { - RelationError::TypeInference(field, TypeError::CannotInferType { value: col }) + let ty = val.type_of().ok_or_else(|| { + RelationError::TypeInference(field, TypeError::CannotInferType { value: val.clone() }) })?; p.push(Column::new(field, ty)); } @@ -225,36 +224,6 @@ impl 
fmt::Display for Header { } } -/// An estimate for the range of rows in the [Relation] -#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)] -pub struct RowCount { - pub min: usize, - pub max: Option, -} - -impl RowCount { - pub fn exact(rows: usize) -> Self { - Self { - min: rows, - max: Some(rows), - } - } - - pub fn unknown() -> Self { - Self { min: 0, max: None } - } -} - -/// A [Relation] is anything that could be represented as a [Header] of `[ColumnName:ColumnType]` that -/// generates rows/tuples of [AlgebraicValue] that exactly match that [Header]. -pub trait Relation { - fn head(&self) -> &Arc
; - /// Specify the size in rows of the [Relation]. - /// - /// Warning: It should at least be precise in the lower-bound estimate. - fn row_count(&self) -> RowCount; -} - /// A stored table from [RelationalDB] #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct DbTable { @@ -275,16 +244,6 @@ impl DbTable { } } -impl Relation for DbTable { - fn head(&self) -> &Arc
{ - &self.head - } - - fn row_count(&self) -> RowCount { - RowCount::unknown() - } -} - #[cfg(test)] mod tests { use super::*; @@ -309,12 +268,11 @@ mod tests { #[test] fn test_project() { - let t1 = 0.into(); let a = 0.into(); let b = 1.into(); - let head = head(t1, "t1", (a, b), 0); - let new = head.project(&[] as &[FieldName]).unwrap(); + let head = head(0, "t1", (a, b), 0); + let new = head.project(&[] as &[ColExpr]).unwrap(); let mut empty = head.clone_for_error(); empty.fields.clear(); @@ -322,26 +280,26 @@ mod tests { assert_eq!(empty, new); let all = head.clone_for_error(); - let new = head.project(&[FieldName::new(t1, a), FieldName::new(t1, b)]).unwrap(); + let new = head.project(&[a, b].map(ColExpr::Col)).unwrap(); assert_eq!(all, new); let mut first = head.clone_for_error(); first.fields.pop(); first.constraints = first.retain_constraints(&a.into()); - let new = head.project(&[FieldName::new(t1, a)]).unwrap(); + let new = head.project(&[a].map(ColExpr::Col)).unwrap(); assert_eq!(first, new); let mut second = head.clone_for_error(); second.fields.remove(0); second.constraints = second.retain_constraints(&b.into()); - let new = head.project(&[FieldName::new(t1, b)]).unwrap(); + let new = head.project(&[b].map(ColExpr::Col)).unwrap(); assert_eq!(second, new); } #[test] fn test_extend() { let t1 = 0.into(); - let t2 = 1.into(); + let t2: TableId = 1.into(); let a = 0.into(); let b = 1.into(); let c = 0.into(); @@ -352,13 +310,13 @@ mod tests { let new = head_lhs.extend(&head_rhs); - let lhs = new.project(&[FieldName::new(t1, a), FieldName::new(t1, b)]).unwrap(); + let lhs = new.project(&[a, b].map(ColExpr::Col)).unwrap(); assert_eq!(head_lhs, lhs); let mut head_rhs = head(t2, "t2", (c, d), 2); head_rhs.table_id = t1; head_rhs.table_name = head_lhs.table_name.clone(); - let rhs = new.project(&[FieldName::new(t2, c), FieldName::new(t2, d)]).unwrap(); + let rhs = new.project(&[2, 3].map(ColId).map(ColExpr::Col)).unwrap(); assert_eq!(head_rhs, rhs); } } diff 
--git a/crates/vm/src/errors.rs b/crates/vm/src/errors.rs index afc0dada2a..873cfa0417 100644 --- a/crates/vm/src/errors.rs +++ b/crates/vm/src/errors.rs @@ -16,8 +16,6 @@ pub enum ConfigError { /// Typing Errors #[derive(Error, Debug)] pub enum ErrorType { - #[error("Field should resolve to `bool`, but it got the value `{{0.to_satn()}}`")] - FieldBool(AlgebraicValue), #[error("Error Parsing `{value}` into type [{ty}]: {err}")] Parse { value: String, ty: String, err: String }, } diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index d3511f9be8..b3551f98c6 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -1,54 +1,19 @@ use crate::errors::ErrorVm; -use crate::expr::{Code, ColumnOp, JoinExpr, ProjectExpr, SourceExpr, SourceSet}; -use crate::expr::{Expr, Query}; -use crate::iterators::RelIter; +use crate::expr::{Code, ColumnOp, Expr, JoinExpr, ProjectExpr, SourceSet}; use crate::program::{ProgramVm, Sources}; use crate::rel_ops::RelOps; use crate::relation::RelValue; -use spacetimedb_sats::relation::Relation; use spacetimedb_sats::ProductValue; pub type IterRows<'a> = dyn RelOps<'a> + 'a; -/// `sources` should be a `Vec` -/// where the `idx`th element is the table referred to in the `query` as `SourceId(idx)`. -/// While constructing the query, the `sources` will be destructively modified with `Option::take` -/// to extract the sources, -/// so the `query` cannot refer to the same `SourceId` multiple times. 
-pub fn build_query<'a, const N: usize>( - mut result: Box>, - query: &'a [Query], - sources: Sources<'_, N>, -) -> Result>, ErrorVm> { - for q in query { - result = match q { - Query::IndexScan(_) => { - panic!("index scans unsupported on memory tables") - } - Query::IndexJoin(_) => { - panic!("index joins unsupported on memory tables") - } - Query::Select(cmp) => build_select(result, cmp), - Query::Project(proj) => build_project(result, proj), - Query::JoinInner(q) => { - let rhs = build_source_expr_query(sources, &q.rhs.source); - let rhs = build_query(rhs, &q.rhs.query, sources)?; - join_inner(result, rhs, q) - } - }; - } - Ok(result) -} - pub fn build_select<'a>(base: impl RelOps<'a> + 'a, cmp: &'a ColumnOp) -> Box> { - let header = base.head().clone(); - Box::new(base.select(move |row| cmp.compare(row, &header))) + Box::new(base.select(move |row| cmp.compare(row))) } pub fn build_project<'a>(base: impl RelOps<'a> + 'a, proj: &'a ProjectExpr) -> Box> { - let header_before = base.head().clone(); - Box::new(base.project(&proj.header_after, &proj.fields, move |cols, row| { - Ok(RelValue::Projection(row.project_owned(cols, &header_before)?)) + Box::new(base.project(&proj.cols, move |cols, row| { + RelValue::Projection(row.project_owned(cols)) })) } @@ -59,27 +24,13 @@ pub fn join_inner<'a>(lhs: impl RelOps<'a> + 'a, rhs: impl RelOps<'a> + 'a, q: & let key_rhs = move |row: &RelValue<'_>| row.read_column(col_rhs).unwrap().into_owned(); let pred = move |l: &RelValue<'_>, r: &RelValue<'_>| l.read_column(col_lhs) == r.read_column(col_rhs); - if let Some(head) = q.inner.as_ref().cloned() { - Box::new(lhs.join_inner(rhs, head, key_lhs, key_rhs, pred, move |l, r| l.extend(r))) + if q.inner.is_some() { + Box::new(lhs.join_inner(rhs, key_lhs, key_rhs, pred, move |l, r| l.extend(r))) } else { - let head = lhs.head().clone(); - Box::new(lhs.join_inner(rhs, head, key_lhs, key_rhs, pred, move |l, _| l)) + Box::new(lhs.join_inner(rhs, key_lhs, key_rhs, pred, move |l, _| l)) } } 
-pub(crate) fn build_source_expr_query<'a, const N: usize>( - sources: Sources<'_, N>, - source: &SourceExpr, -) -> Box> { - let source_id = source.source_id().unwrap_or_else(|| todo!("How pass the db iter?")); - let head = source.head().clone(); - let rc = source.row_count(); - let table = sources.take(source_id).unwrap_or_else(|| { - panic!("Query plan specifies in-mem table for {source_id:?}, but found a `DbTable` or nothing") - }); - Box::new(RelIter::new(head, rc, table.into_iter().map(RelValue::Projection))) -} - /// Execute the code pub fn eval(p: &mut P, code: Code, sources: Sources<'_, N>) -> Code { match code { @@ -225,8 +176,8 @@ pub mod tests { use super::test_helpers::*; use super::*; - use crate::expr::{QueryExpr, SourceSet}; - use crate::program::Program; + use crate::expr::{CrudExpr, Query, QueryExpr, SourceExpr, SourceSet}; + use crate::iterators::RelIter; use crate::relation::MemTable; use spacetimedb_lib::operator::{OpCmp, OpLogic}; use spacetimedb_primitives::ColId; @@ -234,26 +185,77 @@ pub mod tests { use spacetimedb_sats::relation::{FieldName, Header}; use spacetimedb_sats::{product, AlgebraicType, ProductType}; - fn run_query(p: &mut Program, ast: Expr, sources: SourceSet, N>) -> MemTable { - match run_ast(p, ast, sources) { + /// From an original source of `result`s, applies `queries` and returns a final set of results. 
+ fn build_query<'a, const N: usize>( + mut result: Box>, + queries: &'a [Query], + sources: Sources<'_, N>, + ) -> Box> { + for q in queries { + result = match q { + Query::IndexScan(_) | Query::IndexJoin(_) => panic!("unsupported on memory tables"), + Query::Select(cmp) => build_select(result, cmp), + Query::Project(proj) => build_project(result, proj), + Query::JoinInner(q) => { + let rhs = build_source_expr_query(sources, &q.rhs.source); + let rhs = build_query(rhs, &q.rhs.query, sources); + join_inner(result, rhs, q) + } + }; + } + result + } + + fn build_source_expr_query<'a, const N: usize>(sources: Sources<'_, N>, source: &SourceExpr) -> Box> { + let source_id = source.source_id().unwrap(); + let table = sources.take(source_id).unwrap(); + Box::new(RelIter::new(table.into_iter().map(RelValue::Projection))) + } + + /// A default program that run in-memory without a database + struct Program; + + impl ProgramVm for Program { + fn eval_query(&mut self, query: CrudExpr, sources: Sources<'_, N>) -> Result { + match query { + CrudExpr::Query(query) => { + let result = build_source_expr_query(sources, &query.source); + let rows = build_query(result, &query.query, sources).collect_vec(|row| row.into_product_value()); + + let head = query.head().clone(); + + Ok(Code::Table(MemTable::new(head, query.source.table_access(), rows))) + } + _ => todo!(), + } + } + } + + fn run_query(ast: Expr, sources: SourceSet, N>) -> MemTable { + match run_ast(&mut Program, ast, sources) { Code::Table(x) => x, x => panic!("Unexpected result on query: {x}"), } } + fn get_field_pos(table: &MemTable, pos: usize) -> FieldName { + *table.head.fields.get(pos).map(|x| &x.field).unwrap() + } + #[test] fn test_select() { - let p = &mut Program; let input = mem_table_one_u64(0.into()); - let field = *input.get_field_pos(0).unwrap(); + let field = get_field_pos(&input, 0); let mut sources = SourceSet::<_, 1>::empty(); let source_expr = sources.add_mem_table(input); - let q = 
QueryExpr::new(source_expr).with_select_cmp(OpCmp::Eq, field, scalar(1u64)); + let q = QueryExpr::new(source_expr) + .with_select_cmp(OpCmp::Eq, field, scalar(1u64)) + .unwrap(); let head = q.head().clone(); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); let row = product![1u64]; assert_eq!(result, MemTable::from_iter(head, [row]), "Query"); } @@ -267,32 +269,37 @@ pub mod tests { let source_expr = sources.add_mem_table(table.clone()); let source = QueryExpr::new(source_expr); - let field = *table.get_field_pos(0).unwrap(); - let q = source.clone().with_project(&[field.into()], None).unwrap(); + let field = get_field_pos(&table, 0); + let q = source.clone().with_project([field.into()].into(), None).unwrap(); let head = q.head().clone(); let result = run_ast(p, q.into(), sources); let row = product![1u64]; assert_eq!(result, Code::Table(MemTable::from_iter(head.clone(), [row])), "Project"); + } + + #[test] + fn test_project_out_of_bounds() { + let table = mem_table_one_u64(0.into()); let mut sources = SourceSet::<_, 1>::empty(); let source_expr = sources.add_mem_table(table.clone()); let source = QueryExpr::new(source_expr); + // This field is out of bounds of `table`'s header, so `run_ast` will panic. 
let field = FieldName::new(table.head.table_id, 1.into()); assert!(matches!( - source.with_project(&[field.into()], None).unwrap_err(), - RelationError::FieldNotFound(h, f) if h == *head && f == field, + source.with_project([field.into()].into(), None).unwrap_err(), + RelationError::FieldNotFound(_, f) if f == field, )); } #[test] fn test_join_inner() { - let p = &mut Program; let table_id = 0.into(); let table = mem_table_one_u64(table_id); let col: ColId = 0.into(); - let field = table.head().fields[col.idx()].clone(); + let field = table.head.fields[col.idx()].clone(); let mut sources = SourceSet::<_, 2>::empty(); let source_expr = sources.add_mem_table(table.clone()); @@ -300,7 +307,7 @@ pub mod tests { let q = QueryExpr::new(source_expr).with_join_inner(second_source_expr, col, col, false); dbg!(&q); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); // The expected result. let head = Header::new(table_id, "".into(), [field.clone(), field].into(), Vec::new()); @@ -318,7 +325,6 @@ pub mod tests { #[test] fn test_semijoin() { - let p = &mut Program; let table_id = 0.into(); let table = mem_table_one_u64(table_id); let col = 0.into(); @@ -328,7 +334,7 @@ pub mod tests { let second_source_expr = sources.add_mem_table(table); let q = QueryExpr::new(source_expr).with_join_inner(second_source_expr, col, col, true); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); // The expected result. 
let inv = ProductType::from([(None, AlgebraicType::U64)]); @@ -346,8 +352,6 @@ pub mod tests { #[test] fn test_query_logic() { - let p = &mut Program; - let inv = ProductType::from([("id", AlgebraicType::U64), ("name", AlgebraicType::String)]); let row = product![1u64, "health"]; @@ -358,18 +362,22 @@ pub mod tests { let mut sources = SourceSet::<_, 1>::empty(); let source_expr = sources.add_mem_table(input.clone()); - let q = QueryExpr::new(source_expr.clone()).with_select_cmp(OpLogic::And, scalar(true), scalar(true)); + let q = QueryExpr::new(source_expr.clone()) + .with_select_cmp(OpLogic::And, scalar(true), scalar(true)) + .unwrap(); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); assert_eq!(result, inv.clone(), "Query And"); let mut sources = SourceSet::<_, 1>::empty(); let source_expr = sources.add_mem_table(input); - let q = QueryExpr::new(source_expr).with_select_cmp(OpLogic::Or, scalar(true), scalar(false)); + let q = QueryExpr::new(source_expr) + .with_select_cmp(OpLogic::Or, scalar(true), scalar(false)) + .unwrap(); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); assert_eq!(result, inv, "Query Or"); } @@ -378,8 +386,6 @@ pub mod tests { /// Inventory /// | id: u64 | name : String | fn test_query_inner_join() { - let p = &mut Program; - let inv = ProductType::from([("id", AlgebraicType::U64), ("name", AlgebraicType::String)]); let row = product![1u64, "health"]; @@ -394,7 +400,7 @@ pub mod tests { let q = QueryExpr::new(source_expr).with_join_inner(second_source_expr, col, col, false); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); //The expected result let inv = ProductType::from([ @@ -411,8 +417,6 @@ pub mod tests { /// Inventory /// | id: u64 | name : String | fn test_query_semijoin() { - let p = &mut Program; - let inv = ProductType::from([("id", AlgebraicType::U64), ("name", AlgebraicType::String)]); let row = 
product![1u64, "health"]; @@ -427,7 +431,7 @@ pub mod tests { let q = QueryExpr::new(source_expr).with_join_inner(second_source_expr, col, col, true); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); // The expected result. let inv = ProductType::from([(None, AlgebraicType::U64), (Some("name"), AlgebraicType::String)]); @@ -444,16 +448,14 @@ pub mod tests { /// Location /// | entity_id: u64 | x : f32 | z : f32 | fn test_query_game() { - let p = &mut Program; - // See table above. let data = create_game_data(); let inv @ [inv_inventory_id, _] = [0, 1].map(|c| c.into()); - let inv_head = data.inv.head().clone(); + let inv_head = data.inv.head.clone(); let inv_expr = |col: ColId| inv_head.fields[col.idx()].field.into(); let [location_entity_id, location_x, location_z] = [0, 1, 2].map(|c| c.into()); let [player_entity_id, player_inventory_id] = [0, 1].map(|c| c.into()); - let loc_head = data.location.head().clone(); + let loc_head = data.location.head.clone(); let loc_field = |col: ColId| loc_head.fields[col.idx()].field; let inv_table_id = data.inv.head.table_id; let player_table_id = data.player.head.table_id; @@ -472,11 +474,15 @@ pub mod tests { let q = QueryExpr::new(player_source_expr) .with_join_inner(location_source_expr, player_entity_id, location_entity_id, true) .with_select_cmp(OpCmp::Gt, loc_field(location_x), scalar(0.0f32)) + .unwrap() .with_select_cmp(OpCmp::LtEq, loc_field(location_x), scalar(32.0f32)) + .unwrap() .with_select_cmp(OpCmp::Gt, loc_field(location_z), scalar(0.0f32)) - .with_select_cmp(OpCmp::LtEq, loc_field(location_z), scalar(32.0f32)); + .unwrap() + .with_select_cmp(OpCmp::LtEq, loc_field(location_z), scalar(32.0f32)) + .unwrap(); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); let ty = ProductType::from([("entity_id", AlgebraicType::U64), ("inventory_id", AlgebraicType::U64)]); let row1 = product!(100u64, 1u64); @@ -514,13 +520,17 @@ pub mod tests 
{ true, ) .with_select_cmp(OpCmp::Gt, loc_field(location_x), scalar(0.0f32)) + .unwrap() .with_select_cmp(OpCmp::LtEq, loc_field(location_x), scalar(32.0f32)) + .unwrap() .with_select_cmp(OpCmp::Gt, loc_field(location_z), scalar(0.0f32)) + .unwrap() .with_select_cmp(OpCmp::LtEq, loc_field(location_z), scalar(32.0f32)) - .with_project(&inv.map(inv_expr), Some(inv_table_id)) + .unwrap() + .with_project(inv.map(inv_expr).into(), Some(inv_table_id)) .unwrap(); - let result = run_query(p, q.into(), sources); + let result = run_query(q.into(), sources); let ty = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]); let row1 = product!(1u64, "health"); diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index e5c25ee493..a8b18e54d3 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -1,18 +1,19 @@ -use crate::errors::{ErrorKind, ErrorLang, ErrorType, ErrorVm}; +use crate::errors::{ErrorKind, ErrorLang}; use crate::operator::{OpCmp, OpLogic, OpQuery}; use crate::relation::{MemTable, RelValue}; use arrayvec::ArrayVec; use core::slice::from_ref; use derive_more::From; use smallvec::{smallvec, SmallVec}; -use spacetimedb_data_structures::map::{HashMap, HashSet}; -use spacetimedb_lib::Identity; +use spacetimedb_data_structures::map::{HashSet, IntMap}; +use spacetimedb_lib::{AlgebraicType, Identity}; use spacetimedb_primitives::*; use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::db::def::{TableDef, TableSchema}; use spacetimedb_sats::db::error::{AuthError, RelationError}; -use spacetimedb_sats::relation::{DbTable, FieldExpr, FieldName, Header, Relation, RowCount}; +use spacetimedb_sats::relation::{ColExpr, DbTable, FieldName, Header}; +use spacetimedb_sats::satn::Satn; use spacetimedb_sats::ProductValue; use std::cmp::Reverse; use std::collections::btree_map::Entry; @@ -27,9 +28,121 @@ pub trait AuthAccess { } #[derive(Debug, Clone, 
PartialEq, Eq, PartialOrd, Ord, Hash, From)] -pub enum ColumnOp { +pub enum FieldExpr { + Name(FieldName), + Value(AlgebraicValue), +} + +impl FieldExpr { + pub fn strip_table(self) -> ColExpr { + match self { + Self::Name(field) => ColExpr::Col(field.col), + Self::Value(value) => ColExpr::Value(value), + } + } + + pub fn name_to_col(self, head: &Header) -> Result { + match self { + Self::Value(val) => Ok(ColExpr::Value(val)), + Self::Name(field) => head.column_pos_or_err(field).map(ColExpr::Col), + } + } +} + +impl fmt::Display for FieldExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FieldExpr::Name(x) => write!(f, "{x}"), + FieldExpr::Value(x) => write!(f, "{}", x.to_satn()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)] +pub enum FieldOp { #[from] Field(FieldExpr), + Cmp { + op: OpQuery, + lhs: Box, + rhs: Box, + }, +} + +type FieldOpFlat = SmallVec<[FieldOp; 1]>; + +impl FieldOp { + pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self { + Self::Cmp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + } + } + + pub fn cmp(field: impl Into, op: OpCmp, value: impl Into) -> Self { + Self::new( + OpQuery::Cmp(op), + Self::Field(FieldExpr::Name(field.into())), + Self::Field(FieldExpr::Value(value.into())), + ) + } + + pub fn names_to_cols(self, head: &Header) -> Result { + match self { + Self::Field(field) => field.name_to_col(head).map(ColumnOp::Col), + Self::Cmp { op, lhs, rhs } => { + let lhs = lhs.names_to_cols(head)?; + let rhs = rhs.names_to_cols(head)?; + Ok(ColumnOp::new(op, lhs, rhs)) + } + } + } + + /// Flattens a nested conjunction of AND expressions. + /// + /// For example, `a = 1 AND b = 2 AND c = 3` becomes `[a = 1, b = 2, c = 3]`. + /// + /// This helps with splitting the kinds of `queries`, + /// that *could* be answered by a `index`, + /// from the ones that need to be executed with a `scan`. 
+ pub fn flatten_ands(self) -> FieldOpFlat { + fn fill_vec(buf: &mut FieldOpFlat, op: FieldOp) { + match op { + FieldOp::Cmp { + op: OpQuery::Logic(OpLogic::And), + lhs, + rhs, + } => { + fill_vec(buf, *lhs); + fill_vec(buf, *rhs); + } + op => buf.push(op), + } + } + let mut buf = SmallVec::new(); + fill_vec(&mut buf, self); + buf + } +} + +impl fmt::Display for FieldOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Field(x) => { + write!(f, "{}", x) + } + Self::Cmp { op, lhs, rhs } => { + write!(f, "{} {} {}", lhs, op, rhs) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)] +pub enum ColumnOp { + #[from] + Col(ColExpr), Cmp { op: OpQuery, lhs: Box, @@ -37,11 +150,10 @@ pub enum ColumnOp { }, } -type ColumnOpFlat = SmallVec<[ColumnOp; 1]>; type ColumnOpRefFlat<'a> = SmallVec<[&'a ColumnOp; 1]>; impl ColumnOp { - pub fn new(op: OpQuery, lhs: ColumnOp, rhs: ColumnOp) -> Self { + pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self { Self::Cmp { op, lhs: Box::new(lhs), @@ -49,25 +161,22 @@ impl ColumnOp { } } - pub fn cmp(field: impl Into, op: OpCmp, value: impl Into) -> Self { + pub fn cmp(col: impl Into, op: OpCmp, value: impl Into) -> Self { Self::new( OpQuery::Cmp(op), - ColumnOp::Field(FieldExpr::Name(field.into())), - ColumnOp::Field(FieldExpr::Value(value.into())), + Self::Col(ColExpr::Col(col.into())), + Self::Col(ColExpr::Value(value.into())), ) } /// Returns a new op where `lhs` and `rhs` are logically AND-ed together. - fn and(lhs: ColumnOp, rhs: ColumnOp) -> Self { + fn and(lhs: Self, rhs: Self) -> Self { Self::new(OpQuery::Logic(OpLogic::And), lhs, rhs) } /// Returns an op where `col_i op value_i` are all `AND`ed together. 
- fn and_cmp(op: OpCmp, head: &Header, cols: &ColList, value: AlgebraicValue) -> Self { - let eq = |(col, value): (ColId, _)| { - let field = head.fields[col.idx()].field; - Self::cmp(field, op, value) - }; + fn and_cmp(op: OpCmp, cols: &ColList, value: AlgebraicValue) -> Self { + let eq = |(col, value): (ColId, _)| Self::cmp(col, op, value); // For singleton constraints, the `value` must be used directly. if cols.is_singleton() { @@ -84,11 +193,7 @@ impl ColumnOp { /// Returns an op where `cols` must be within bounds. /// This handles both the case of single-col bounds and multi-col bounds. - fn from_op_col_bounds( - head: &Header, - cols: &ColList, - bounds: (Bound, Bound), - ) -> Self { + fn from_op_col_bounds(cols: &ColList, bounds: (Bound, Bound)) -> Self { let (cmp, value) = match bounds { // Equality; field <= value && field >= value <=> field = value (Bound::Included(a), Bound::Included(b)) if a == b => (OpCmp::Eq, a), @@ -102,105 +207,64 @@ impl ColumnOp { (Bound::Unbounded, Bound::Excluded(value)) => (OpCmp::Lt, value), (Bound::Unbounded, Bound::Unbounded) => unreachable!(), (lower_bound, upper_bound) => { - let lhs = Self::from_op_col_bounds(head, cols, (lower_bound, Bound::Unbounded)); - let rhs = Self::from_op_col_bounds(head, cols, (Bound::Unbounded, upper_bound)); + let lhs = Self::from_op_col_bounds(cols, (lower_bound, Bound::Unbounded)); + let rhs = Self::from_op_col_bounds(cols, (Bound::Unbounded, upper_bound)); return ColumnOp::and(lhs, rhs); } }; - ColumnOp::and_cmp(cmp, head, cols, value) + ColumnOp::and_cmp(cmp, cols, value) } - fn reduce(&self, row: &RelValue<'_>, value: &ColumnOp, header: &Header) -> Result { + fn reduce(&self, row: &RelValue<'_>, value: &Self) -> AlgebraicValue { match value { - ColumnOp::Field(field) => Ok(row.get(field.borrowed(), header)?.into_owned()), - ColumnOp::Cmp { op, lhs, rhs } => Ok(self.compare_bin_op(row, *op, lhs, rhs, header)?.into()), + Self::Col(field) => row.get(field.borrowed()).into_owned(), + Self::Cmp 
{ op, lhs, rhs } => self.compare_bin_op(row, *op, lhs, rhs).into(), } } - fn reduce_bool(&self, row: &RelValue<'_>, value: &ColumnOp, header: &Header) -> Result { + fn reduce_bool(&self, row: &RelValue<'_>, value: &Self) -> bool { match value { - ColumnOp::Field(field) => { - let field = row.get(field.borrowed(), header)?; - - match field.as_bool() { - Some(b) => Ok(*b), - None => Err(ErrorType::FieldBool(field.into_owned()).into()), - } - } - ColumnOp::Cmp { op, lhs, rhs } => Ok(self.compare_bin_op(row, *op, lhs, rhs, header)?), + Self::Col(field) => *row.get(field.borrowed()).as_bool().unwrap(), + Self::Cmp { op, lhs, rhs } => self.compare_bin_op(row, *op, lhs, rhs), } } - fn compare_bin_op( - &self, - row: &RelValue<'_>, - op: OpQuery, - lhs: &ColumnOp, - rhs: &ColumnOp, - header: &Header, - ) -> Result { + fn compare_bin_op(&self, row: &RelValue<'_>, op: OpQuery, lhs: &Self, rhs: &Self) -> bool { match op { OpQuery::Cmp(op) => { - let lhs = self.reduce(row, lhs, header)?; - let rhs = self.reduce(row, rhs, header)?; - - Ok(match op { + let lhs = self.reduce(row, lhs); + let rhs = self.reduce(row, rhs); + match op { OpCmp::Eq => lhs == rhs, OpCmp::NotEq => lhs != rhs, OpCmp::Lt => lhs < rhs, OpCmp::LtEq => lhs <= rhs, OpCmp::Gt => lhs > rhs, OpCmp::GtEq => lhs >= rhs, - }) + } } OpQuery::Logic(op) => { - let lhs = self.reduce_bool(row, lhs, header)?; - let rhs = self.reduce_bool(row, rhs, header)?; + let lhs = self.reduce_bool(row, lhs); + let rhs = self.reduce_bool(row, rhs); - Ok(match op { + match op { OpLogic::And => lhs && rhs, OpLogic::Or => lhs || rhs, - }) + } } } } - pub fn compare(&self, row: &RelValue<'_>, header: &Header) -> Result { + pub fn compare(&self, row: &RelValue<'_>) -> bool { match self { - ColumnOp::Field(field) => { - let lhs = row.get(field.borrowed(), header)?; - Ok(*lhs.as_bool().unwrap()) + Self::Col(field) => { + let lhs = row.get(field.borrowed()); + *lhs.as_bool().unwrap() } - ColumnOp::Cmp { op, lhs, rhs } => 
self.compare_bin_op(row, *op, lhs, rhs, header), + Self::Cmp { op, lhs, rhs } => self.compare_bin_op(row, *op, lhs, rhs), } } - /// Flattens a nested conjunction of AND expressions. - /// - /// For example, `a = 1 AND b = 2 AND c = 3` becomes `[a = 1, b = 2, c = 3]`. - /// - /// This helps with splitting the kinds of `queries`, - /// that *could* be answered by a `index`, - /// from the ones that need to be executed with a `scan`. - pub fn flatten_ands(self) -> ColumnOpFlat { - fn fill_vec(buf: &mut ColumnOpFlat, op: ColumnOp) { - match op { - ColumnOp::Cmp { - op: OpQuery::Logic(OpLogic::And), - lhs, - rhs, - } => { - fill_vec(buf, *lhs); - fill_vec(buf, *rhs); - } - op => buf.push(op), - } - } - let mut buf = SmallVec::new(); - fill_vec(&mut buf, self); - buf - } - /// Flattens a nested conjunction of AND expressions. /// /// For example, `a = 1 AND b = 2 AND c = 3` becomes `[a = 1, b = 2, c = 3]`. @@ -231,7 +295,7 @@ impl ColumnOp { impl fmt::Display for ColumnOp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ColumnOp::Field(x) => { + ColumnOp::Col(x) => { write!(f, "{}", x) } ColumnOp::Cmp { op, lhs, rhs } => { @@ -241,22 +305,21 @@ impl fmt::Display for ColumnOp { } } -impl From for ColumnOp { - fn from(value: FieldName) -> Self { - ColumnOp::Field(value.into()) +impl From for ColumnOp { + fn from(value: ColId) -> Self { + ColumnOp::Col(value.into()) } } - impl From for ColumnOp { fn from(value: AlgebraicValue) -> Self { - ColumnOp::Field(value.into()) + Self::Col(value.into()) } } impl From for Option { fn from(value: Query) -> Self { match value { - Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.table.head, &op.columns, op.bounds)), + Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.columns, op.bounds)), Query::Select(op) => Some(op), _ => None, } @@ -412,9 +475,8 @@ impl SourceSet, N> { /// Insert a [`MemTable`] into this `SourceSet` so it can be used in a query plan, /// and return a [`SourceExpr`] 
which can be embedded in that plan. pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr { - let len = table.data.len(); let id = self.add(table.data); - SourceExpr::from_mem_table(table.head, table.table_access, len, id) + SourceExpr::from_mem_table(table.head, table.table_access, id) } } @@ -434,7 +496,6 @@ pub enum SourceExpr { header: Arc
, table_type: StTableType, table_access: StAccess, - row_count: RowCount, }, /// A plan for a database table. Because [`DbTable`] is small and efficiently cloneable, /// no indirection into a [`SourceSet`] is required. @@ -487,13 +548,12 @@ impl SourceExpr { matches!(self, SourceExpr::DbTable(_)) } - pub fn from_mem_table(header: Arc
, table_access: StAccess, row_count: usize, id: SourceId) -> Self { + pub fn from_mem_table(header: Arc
, table_access: StAccess, id: SourceId) -> Self { SourceExpr::InMemory { source_id: id, header, table_type: StTableType::User, table_access, - row_count: RowCount::exact(row_count), } } @@ -519,19 +579,6 @@ impl SourceExpr { } } -impl Relation for SourceExpr { - fn head(&self) -> &Arc
{ - self.head() - } - - fn row_count(&self) -> RowCount { - match self { - SourceExpr::InMemory { row_count, .. } => *row_count, - SourceExpr::DbTable(_) => RowCount::unknown(), - } - } -} - impl From<&TableSchema> for SourceExpr { fn from(value: &TableSchema) -> Self { SourceExpr::DbTable(value.into()) @@ -725,13 +772,13 @@ pub enum CrudExpr { }, Update { delete: QueryExpr, - assignments: HashMap, + assignments: IntMap, }, Delete { query: QueryExpr, }, CreateTable { - table: TableDef, + table: Box, }, Drop { name: String, @@ -781,7 +828,7 @@ impl IndexScan { /// A projection operation in a query. #[derive(Debug, Clone, Eq, PartialEq, From, Hash)] pub struct ProjectExpr { - pub fields: Vec, + pub cols: Vec, // The table id for a qualified wildcard project, if any. // If present, further optimizations are possible. pub wildcard_table: Option, @@ -881,34 +928,34 @@ fn make_index_arg(cmp: OpCmp, columns: &ColList, value: AlgebraicValue) -> Index } #[derive(Debug)] -struct FieldValue<'a> { +struct ColValue<'a> { parent: &'a ColumnOp, + col: ColId, cmp: OpCmp, - field: FieldName, value: &'a AlgebraicValue, } -impl<'a> FieldValue<'a> { - pub fn new(parent: &'a ColumnOp, cmp: OpCmp, field: FieldName, value: &'a AlgebraicValue) -> Self { +impl<'a> ColValue<'a> { + pub fn new(parent: &'a ColumnOp, col: ColId, cmp: OpCmp, value: &'a AlgebraicValue) -> Self { Self { parent, + col, cmp, - field, value, } } } type IndexColumnOpSink<'a> = SmallVec<[IndexColumnOp<'a>; 1]>; -type FieldsIndexed = HashSet<(FieldName, OpCmp)>; +type ColsIndexed = HashSet<(ColId, OpCmp)>; -/// Pick the best indices that can serve the constraints in `fields` +/// Pick the best indices that can serve the constraints in `cols` /// where the indices are taken from `header`. /// /// This function is designed to handle complex scenarios when selecting the optimal index for a query. /// The scenarios include: /// -/// - Combinations of multi- and single-column indexes that could refer to the same field. 
+/// - Combinations of multi- and single-column indexes that could refer to the same column. /// For example, the table could have indexes `[a]` and `[a, b]]` /// and a user could query for `WHERE a = 1 AND b = 2 AND a = 3`. /// @@ -927,7 +974,7 @@ type FieldsIndexed = HashSet<(FieldName, OpCmp)>; /// /// - A vector of `ScanOrIndex` representing the selected `index` OR `scan` operations. /// -/// - A HashSet of `(FieldName, OpCmp)` representing the fields +/// - A HashSet of `(ColId, OpCmp)` representing the columns /// and operators that can be served by an index. /// /// This is required to remove the redundant operation on e.g., @@ -951,7 +998,7 @@ type FieldsIndexed = HashSet<(FieldName, OpCmp)>; /// would give us two separate `IndexScan`s. /// However, the upper layers of `QueryExpr` building will convert both of those into `Select`s. fn select_best_index<'a>( - fields_indexed: &mut FieldsIndexed, + cols_indexed: &mut ColsIndexed, header: &'a Header, ops: &[&'a ColumnOp], ) -> IndexColumnOpSink<'a> { @@ -968,17 +1015,17 @@ fn select_best_index<'a>( let mut found: IndexColumnOpSink = IndexColumnOpSink::new(); - // Collect fields into a multi-map `(col_id, cmp) -> [field]`. + // Collect fields into a multi-map `(col_id, cmp) -> [col value]`. // This gives us `log(N)` seek + deletion. // TODO(Centril): Consider https://docs.rs/small-map/0.1.3/small_map/enum.SmallMap.html - let mut fields_map = BTreeMap::<_, SmallVec<[_; 1]>>::new(); - extract_fields(ops, header, &mut fields_map, &mut found); + let mut col_map = BTreeMap::<_, SmallVec<[_; 1]>>::new(); + extract_cols(ops, &mut col_map, &mut found); // Go through each index, - // consuming all field constraints that can be served by an index. + // consuming all column constraints that can be served by an index. for col_list in indices { - // (1) No fields left? We're done. - if fields_map.is_empty() { + // (1) No columns left? We're done. 
+ if col_map.is_empty() { break; } @@ -989,11 +1036,9 @@ fn select_best_index<'a>( for cmp in [OpCmp::Eq, OpCmp::Lt, OpCmp::LtEq, OpCmp::Gt, OpCmp::GtEq] { // For a single column index, // we want to avoid the `ProductValue` indirection of below. - for FieldValue { cmp, value, field, .. } in - fields_map.remove(&(col_list.head(), cmp)).into_iter().flatten() - { + for ColValue { cmp, value, col, .. } in col_map.remove(&(col_list.head(), cmp)).into_iter().flatten() { found.push(make_index_arg(cmp, col_list, value.clone())); - fields_indexed.insert((field, cmp)); + cols_indexed.insert((col, cmp)); } } } else { @@ -1009,7 +1054,7 @@ fn select_best_index<'a>( // Compute the minimum number of `=` constraints that every column in the index has. let mut min_all_cols_num_eq = col_list .iter() - .map(|col| fields_map.get(&(col, cmp)).map_or(0, |fs| fs.len())) + .map(|col| col_map.get(&(col, cmp)).map_or(0, |fs| fs.len())) .min() .unwrap_or_default(); @@ -1019,10 +1064,10 @@ fn select_best_index<'a>( let mut elems = Vec::with_capacity(col_list.len() as usize); for col in col_list.iter() { // Cannot panic as `min_all_cols_num_eq > 0`. - let field = pop_multimap(&mut fields_map, (col, cmp)).unwrap(); - fields_indexed.insert((field.field, cmp)); - // Add the field value to the product value. - elems.push(field.value.clone()); + let col_val = pop_multimap(&mut col_map, (col, cmp)).unwrap(); + cols_indexed.insert((col_val.col, cmp)); + // Add the column value to the product value. + elems.push(col_val.value.clone()); } // Construct the index scan. let value = AlgebraicValue::product(elems); @@ -1034,7 +1079,7 @@ fn select_best_index<'a>( // The remaining constraints must be served by a scan. 
found.extend( - fields_map + col_map .into_iter() .flat_map(|(_, fs)| fs) .map(|f| IndexColumnOp::Scan(f.parent)), @@ -1057,46 +1102,38 @@ fn pop_multimap(map: &mut BTreeMap( - header: &'a Header, - lhs: &'a ColumnOp, - rhs: &'a ColumnOp, -) -> Option<(ColId, FieldName, &'a AlgebraicValue)> { - if let (ColumnOp::Field(FieldExpr::Name(name)), ColumnOp::Field(FieldExpr::Value(val))) = (lhs, rhs) { - return header.field_name(*name).map(|(id, col)| (id, col, val)); +/// Extracts `name = val` when `lhs` is a col and `rhs` is a value. +fn ext_field_val<'a>(lhs: &'a ColumnOp, rhs: &'a ColumnOp) -> Option<(ColId, &'a AlgebraicValue)> { + if let (ColumnOp::Col(ColExpr::Col(col)), ColumnOp::Col(ColExpr::Value(val))) = (lhs, rhs) { + return Some((*col, val)); } None } -/// Extracts `name = val` when `op` is `name = val` and `name` exists. -fn ext_cmp_field_val<'a>( - header: &'a Header, - op: &'a ColumnOp, -) -> Option<(&'a OpCmp, ColId, FieldName, &'a AlgebraicValue)> { +/// Extracts `name = val` when `op` is `name = val`. +fn ext_cmp_field_val(op: &ColumnOp) -> Option<(OpCmp, ColId, &AlgebraicValue)> { match op { ColumnOp::Cmp { op: OpQuery::Cmp(op), lhs, rhs, - } => ext_field_val(header, lhs, rhs).map(|(id, f, v)| (op, id, f, v)), + } => ext_field_val(lhs, rhs).map(|(id, v)| (*op, id, v)), _ => None, } } -/// Extracts a list of `field = val` constraints that *could* be answered by an index -/// and populates those into `fields_map`. -/// The [`ColumnOp`]s that don't fit `field = val` +/// Extracts a list of `col = val` constraints that *could* be answered by an index +/// and populates those into `col_map`. +/// The [`ColumnOp`]s that don't fit `col = val` /// are made into [`IndexColumnOp::Scan`]s immediately which are added to `found`. 
-fn extract_fields<'a>( +fn extract_cols<'a>( ops: &[&'a ColumnOp], - header: &'a Header, - fields_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[FieldValue<'a>; 1]>>, + col_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[ColValue<'a>; 1]>>, found: &mut IndexColumnOpSink<'a>, ) { - let mut add_field = |parent, op, field_col, field, val| { - let fv = FieldValue::new(parent, op, field, val); - fields_map.entry((field_col, op)).or_default().push(fv); + let mut add_field = |parent, op, col, val| { + let fv = ColValue::new(parent, col, op, val); + col_map.entry((col, op)).or_default().push(fv); }; for op in ops { @@ -1106,9 +1143,9 @@ fn extract_fields<'a>( lhs, rhs, } => { - if let Some((field_col, field, val)) = ext_field_val(header, lhs, rhs) { + if let Some((field_col, val)) = ext_field_val(lhs, rhs) { // `lhs` must be a field that exists and `rhs` must be a value. - add_field(op, *cmp, field_col, field, val); + add_field(op, *cmp, field_col, val); continue; } } @@ -1117,11 +1154,11 @@ fn extract_fields<'a>( lhs, rhs, } => { - if let Some((op_lhs, col_lhs_id, col_lhs, val_lhs)) = ext_cmp_field_val(header, lhs) { - if let Some((op_rhs, col_rhs_id, col_rhs, val_rhs)) = ext_cmp_field_val(header, rhs) { + if let Some((op_lhs, col_lhs, val_lhs)) = ext_cmp_field_val(lhs) { + if let Some((op_rhs, col_rhs, val_rhs)) = ext_cmp_field_val(rhs) { // Both lhs and rhs columns must exist. - add_field(op, *op_lhs, col_lhs_id, col_lhs, val_lhs); - add_field(op, *op_rhs, col_rhs_id, col_rhs, val_rhs); + add_field(op, op_lhs, col_lhs, val_lhs); + add_field(op, op_rhs, col_rhs, val_rhs); continue; } } @@ -1130,7 +1167,7 @@ fn extract_fields<'a>( op: OpQuery::Logic(OpLogic::Or), .. } - | ColumnOp::Field(_) => {} + | ColumnOp::Col(_) => {} } found.push(IndexColumnOp::Scan(op)); @@ -1140,7 +1177,7 @@ fn extract_fields<'a>( /// Sargable stands for Search ARGument ABLE. /// A sargable predicate is one that can be answered using an index. 
fn find_sargable_ops<'a>( - fields_indexed: &mut FieldsIndexed, + fields_indexed: &mut ColsIndexed, header: &'a Header, op: &'a ColumnOp, ) -> SmallVec<[IndexColumnOp<'a>; 1]> { @@ -1148,7 +1185,7 @@ fn find_sargable_ops<'a>( if ops_flat.len() == 1 { match ops_flat.swap_remove(0) { // Special case; fast path for a single field. - op @ ColumnOp::Field(_) => smallvec![IndexColumnOp::Scan(op)], + op @ ColumnOp::Col(_) => smallvec![IndexColumnOp::Scan(op)], op => select_best_index(fields_indexed, header, &[op]), } } else { @@ -1203,7 +1240,7 @@ impl QueryExpr { .rev() .find_map(|op| match op { Query::Select(_) => None, - Query::IndexScan(scan) => Some(scan.table.head()), + Query::IndexScan(scan) => Some(&scan.table.head), Query::IndexJoin(join) if join.return_index_rows => Some(join.index_side.head()), Query::IndexJoin(join) => Some(join.probe_side.head()), Query::Project(proj) => Some(&proj.header_after), @@ -1268,14 +1305,14 @@ impl QueryExpr { } // merge with a preceding select Query::Select(filter) => { - let op = ColumnOp::and_cmp(OpCmp::Eq, &table.head, &columns, value); + let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value); self.query.push(Query::Select(ColumnOp::and(filter, op))); self } // else generate a new select query => { self.query.push(query); - let op = ColumnOp::and_cmp(OpCmp::Eq, &table.head, &columns, value); + let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value); self.query.push(Query::Select(op)); self } @@ -1366,7 +1403,7 @@ impl QueryExpr { // merge with a preceding select Query::Select(filter) => { let bounds = (Self::bound(value, inclusive), Bound::Unbounded); - let op = ColumnOp::from_op_col_bounds(&table.head, &columns, bounds); + let op = ColumnOp::from_op_col_bounds(&columns, bounds); self.query.push(Query::Select(ColumnOp::and(filter, op))); self } @@ -1374,7 +1411,7 @@ impl QueryExpr { query => { self.query.push(query); let bounds = (Self::bound(value, inclusive), Bound::Unbounded); - let op = 
ColumnOp::from_op_col_bounds(&table.head, &columns, bounds); + let op = ColumnOp::from_op_col_bounds(&columns, bounds); self.query.push(Query::Select(op)); self } @@ -1468,7 +1505,7 @@ impl QueryExpr { // merge with a preceding select Query::Select(filter) => { let bounds = (Bound::Unbounded, Self::bound(value, inclusive)); - let op = ColumnOp::from_op_col_bounds(&table.head, &columns, bounds); + let op = ColumnOp::from_op_col_bounds(&columns, bounds); self.query.push(Query::Select(ColumnOp::and(filter, op))); self } @@ -1476,92 +1513,194 @@ impl QueryExpr { query => { self.query.push(query); let bounds = (Bound::Unbounded, Self::bound(value, inclusive)); - let op = ColumnOp::from_op_col_bounds(&table.head, &columns, bounds); + let op = ColumnOp::from_op_col_bounds(&columns, bounds); self.query.push(Query::Select(op)); self } } } - pub fn with_select(mut self, op: O) -> Self + pub fn with_select(mut self, op: O) -> Result where - O: Into, + O: Into, { + let op = op.into(); let Some(query) = self.query.pop() else { - self.query.push(Query::Select(op.into())); - return self; + return self.add_base_select(op); }; - match (query, op.into()) { + match (query, op) { ( Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, - inner: semi, + inner, }), - ColumnOp::Cmp { + FieldOp::Cmp { op: OpQuery::Cmp(cmp), lhs: field, rhs: value, }, ) => match (*field, *value) { - (ColumnOp::Field(FieldExpr::Name(field)), ColumnOp::Field(FieldExpr::Value(value))) + (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value))) // Field is from lhs, so push onto join's left arg if self.head().column_pos(field).is_some() => { - self = self.with_select(ColumnOp::cmp(field, cmp, value)); - self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner: semi})); - self + // No typing restrictions on `field cmp value`, + // and there are no binary operators to recurse into. 
+ self = self.with_select(FieldOp::cmp(field, cmp, value))?; + self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner })); + Ok(self) } - (ColumnOp::Field(FieldExpr::Name(field)), ColumnOp::Field(FieldExpr::Value(value))) + (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value))) // Field is from rhs, so push onto join's right arg if rhs.head().column_pos(field).is_some() => { + // No typing restrictions on `field cmp value`, + // and there are no binary operators to recurse into. + let rhs = rhs.with_select(FieldOp::cmp(field, cmp, value))?; self.query.push(Query::JoinInner(JoinExpr { - rhs: rhs.with_select(ColumnOp::cmp(field, cmp, value)), + rhs, col_lhs, col_rhs, - inner: semi, + inner, })); - self + Ok(self) } (field, value) => { - self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner: semi, })); - self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), field, value))); - self + self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner, })); + + // As we have `field op value` we need not demand `bool`, + // but we must still recuse into each side. + self.check_field_op_logics(&field)?; + self.check_field_op_logics(&value)?; + // Convert to `ColumnOp`. + let col = field.names_to_cols(self.head()).unwrap(); + let value = value.names_to_cols(self.head()).unwrap(); + // Add `col op value` filter to query. + self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), col, value))); + Ok(self) } }, - (Query::Select(filter), op) => { - self.query.push(Query::Select(ColumnOp::and(filter, op))); - self + // We have a previous filter `lhs`, so join with `rhs` forming `lhs AND rhs`. + (Query::Select(lhs), rhs) => { + // Type check `rhs`, demanding `bool`. + self.check_field_op(&rhs)?; + // Convert to `ColumnOp`. + let rhs = rhs.names_to_cols(self.head()).unwrap(); + // Add `lhs AND op` to query. 
+ self.query.push(Query::Select(ColumnOp::and(lhs, rhs))); + Ok(self) } + // No previous filter, so add a base one. (query, op) => { self.query.push(query); - self.query.push(Query::Select(op)); - self + self.add_base_select(op) + } + } + } + + /// Add a base `Select` query that filters according to `op`. + /// The `op` is checked to produce a `bool` value. + fn add_base_select(mut self, op: FieldOp) -> Result { + // Type check the filter, demanding `bool`. + self.check_field_op(&op)?; + // Convert to `ColumnOp`. + let op = op.names_to_cols(self.head()).unwrap(); + // Add the filter. + self.query.push(Query::Select(op)); + Ok(self) + } + + /// Type checks a `FieldOp` with respect to `self`, + /// ensuring that query evaluation cannot get stuck or panic due to `reduce_bool`. + fn check_field_op(&self, op: &FieldOp) -> Result<(), RelationError> { + use OpQuery::*; + match op { + // `lhs` and `rhs` must both be typed at `bool`. + FieldOp::Cmp { op: Logic(_), lhs, rhs } => { + self.check_field_op(lhs)?; + self.check_field_op(rhs)?; + Ok(()) + } + // `lhs` and `rhs` have no typing restrictions. + // The result of `lhs op rhs` will always be a `bool` + // either by `Eq` or `Ord` on `AlgebraicValue` (see `ColumnOp::compare_bin_op`). + // However, we still have to recurse into `lhs` and `rhs` + // in case we have e.g., `a == (b == c)`. 
+ FieldOp::Cmp { op: Cmp(_), lhs, rhs } => { + self.check_field_op_logics(lhs)?; + self.check_field_op_logics(rhs)?; + Ok(()) + } + FieldOp::Field(FieldExpr::Value(AlgebraicValue::Bool(_))) => Ok(()), + FieldOp::Field(FieldExpr::Value(v)) => Err(RelationError::NotBoolValue { val: v.clone() }), + FieldOp::Field(FieldExpr::Name(field)) => { + let field = *field; + let head = self.head(); + let col_id = head.column_pos_or_err(field)?; + let col_ty = &head.fields[col_id.idx()].algebraic_type; + match col_ty { + &AlgebraicType::Bool => Ok(()), + ty => Err(RelationError::NotBoolType { field, ty: ty.clone() }), + } + } + } + } + + /// Traverses `op`, checking any logical operators for bool-typed operands. + fn check_field_op_logics(&self, op: &FieldOp) -> Result<(), RelationError> { + use OpQuery::*; + match op { + FieldOp::Field(_) => Ok(()), + FieldOp::Cmp { op: Cmp(_), lhs, rhs } => { + self.check_field_op_logics(lhs)?; + self.check_field_op_logics(rhs)?; + Ok(()) + } + FieldOp::Cmp { op: Logic(_), lhs, rhs } => { + self.check_field_op(lhs)?; + self.check_field_op(rhs)?; + Ok(()) } } } - pub fn with_select_cmp(self, op: O, lhs: LHS, rhs: RHS) -> Self + pub fn with_select_cmp(self, op: O, lhs: LHS, rhs: RHS) -> Result where LHS: Into, RHS: Into, O: Into, { - let op = ColumnOp::new(op.into(), ColumnOp::Field(lhs.into()), ColumnOp::Field(rhs.into())); + let op = FieldOp::new(op.into(), FieldOp::Field(lhs.into()), FieldOp::Field(rhs.into())); self.with_select(op) } // Appends a project operation to the query operator pipeline. // The `wildcard_table_id` represents a projection of the form `table.*`. // This is used to determine if an inner join can be rewritten as an index join. 
- pub fn with_project(mut self, cols: &[FieldExpr], wildcard_table: Option) -> Result { - if !cols.is_empty() { - let header_after = Arc::new(self.head().project(cols)?); + pub fn with_project( + mut self, + fields: Vec, + wildcard_table: Option, + ) -> Result { + if !fields.is_empty() { + let header_before = self.head(); + + // Translate the field expressions to column expressions. + let mut cols = Vec::with_capacity(fields.len()); + for field in fields { + cols.push(field.name_to_col(header_before)?); + } + + // Project the header. + // We'll store that so subsequent operations use that as a base. + let header_after = Arc::new(header_before.project(&cols)?); + + // Add the projection. self.query.push(Query::Project(ProjectExpr { - fields: cols.into(), + cols, wildcard_table, header_after, })); @@ -1707,36 +1846,25 @@ impl QueryExpr { let join = query.query.pop().unwrap(); match join { - Query::JoinInner(JoinExpr { - rhs: probe_side, - col_lhs: index_col, - col_rhs: probe_col, - inner: None, - }) => { - if !probe_side.query.is_empty() { + Query::JoinInner(join @ JoinExpr { inner: None, .. }) => { + if !join.rhs.query.is_empty() { // An applicable join must have an index defined on the correct field. 
- if source.head().has_constraint(index_col, Constraints::indexed()) { + if source.head().has_constraint(join.col_lhs, Constraints::indexed()) { let index_join = IndexJoin { - probe_side, - probe_col, + probe_side: join.rhs, + probe_col: join.col_rhs, index_side: source.clone(), index_select: None, - index_col, + index_col: join.col_lhs, return_index_rows: true, }; let query = [Query::IndexJoin(index_join)].into(); return QueryExpr { source, query }; } } - let join = Query::JoinInner(JoinExpr { - rhs: probe_side, - col_lhs: index_col, - col_rhs: probe_col, - inner: None, - }); QueryExpr { source, - query: vec![join], + query: vec![Query::JoinInner(join)], } } first => QueryExpr { @@ -1754,11 +1882,11 @@ impl QueryExpr { for schema in tables { for op in find_sargable_ops(&mut fields_found, schema.head(), &op) { match &op { - IndexColumnOp::Index(_) | IndexColumnOp::Scan(ColumnOp::Field(_)) => {} + IndexColumnOp::Index(_) | IndexColumnOp::Scan(ColumnOp::Col(_)) => {} // Remove a duplicated/redundant operation on the same `field` and `op` // like `[ScanOrIndex::Index(a = 1), ScanOrIndex::Index(a = 1), ScanOrIndex::Scan(a = 1)]` IndexColumnOp::Scan(ColumnOp::Cmp { op, lhs, rhs: _ }) => { - if let (ColumnOp::Field(FieldExpr::Name(col)), OpQuery::Cmp(cmp)) = (&**lhs, op) { + if let (ColumnOp::Col(ColExpr::Col(col)), OpQuery::Cmp(cmp)) = (&**lhs, op) { if !fields_found.insert((*col, *cmp)) { continue; } @@ -1767,44 +1895,42 @@ impl QueryExpr { } match op { - IndexColumnOp::Index(idx) => match idx { - // Found sargable equality condition for one of the table schemas. - IndexArgument::Eq { columns, value } => { - // `unwrap` here is infallible because `is_sargable(schema, op)` implies `schema.is_db_table` - // for any `op`. - q = q.with_index_eq(schema.get_db_table().unwrap().clone(), columns.clone(), value); - } - // Found sargable range condition for one of the table schemas. 
- IndexArgument::LowerBound { - columns, - value, - inclusive, - } => { - // `unwrap` here is infallible because `is_sargable(schema, op)` implies `schema.is_db_table` - // for any `op`. - q = q.with_index_lower_bound( - schema.get_db_table().unwrap().clone(), - columns.clone(), + // A sargable condition for on one of the table schemas, + // either an equality or range condition. + IndexColumnOp::Index(idx) => { + let table = schema + .get_db_table() + .expect("find_sargable_ops(schema, op) implies `schema.is_db_table()`") + .clone(); + + q = match idx { + IndexArgument::Eq { columns, value } => q.with_index_eq(table, columns.clone(), value), + IndexArgument::LowerBound { + columns, value, inclusive, - ); - } - // Found sargable range condition for one of the table schemas. - IndexArgument::UpperBound { - columns, - value, - inclusive, - } => { - q = q.with_index_upper_bound( - schema.get_db_table().unwrap().clone(), - columns.clone(), + } => q.with_index_lower_bound(table, columns.clone(), value, inclusive), + IndexArgument::UpperBound { + columns, value, inclusive, - ); - } - }, + } => q.with_index_upper_bound(table, columns.clone(), value, inclusive), + }; + } // Filter condition cannot be answered using an index. - IndexColumnOp::Scan(scan) => q = q.with_select(scan.clone()), + IndexColumnOp::Scan(rhs) => { + let rhs = rhs.clone(); + let op = match q.query.pop() { + // Merge condition into any pre-existing `Select`. 
+ Some(Query::Select(lhs)) => ColumnOp::and(lhs, rhs), + None => rhs, + Some(other) => { + q.query.push(other); + rhs + } + }; + q.query.push(Query::Select(op)); + } } } } @@ -1886,7 +2012,7 @@ impl fmt::Display for Query { write!(f, "select {q}") } Query::Project(proj) => { - let q = &proj.fields; + let q = &proj.cols; write!(f, "project")?; if !q.is_empty() { write!(f, " ")?; @@ -2018,7 +2144,6 @@ mod tests { fields: vec![], constraints: Default::default(), }), - row_count: RowCount::unknown(), table_type: StTableType::User, table_access: StAccess::Private, }, @@ -2113,7 +2238,6 @@ mod tests { SourceExpr::InMemory { source_id: SourceId(0), header: Arc::new(head), - row_count: RowCount::unknown(), table_access, table_type: StTableType::User, } @@ -2134,8 +2258,7 @@ mod tests { let index_col = 1.into(); let probe_col = 1.into(); - let select_field = FieldName::new(index_side.head().table_id, 0.into()); - let index_select = ColumnOp::cmp(select_field, OpCmp::Eq, 0u8); + let index_select = ColumnOp::cmp(0, OpCmp::Eq, 0u8); let join = IndexJoin { probe_side: probe_side.clone().into(), probe_col, @@ -2166,19 +2289,18 @@ mod tests { assert_eq!(join.inner, None); } - fn setup_best_index() -> (Header, [FieldName; 5], [AlgebraicValue; 5]) { + fn setup_best_index() -> (Header, [ColId; 5], [AlgebraicValue; 5]) { let table_id = 0.into(); let vals = [1, 2, 3, 4, 5].map(AlgebraicValue::U64); let col_ids = [0, 1, 2, 3, 4].map(ColId); let [a, b, c, d, _] = col_ids; - let fields = col_ids.map(|c| FieldName::new(table_id, c)); - let cols = fields.map(|f| Column::new(f, AlgebraicType::I8)); + let columns = col_ids.map(|c| Column::new(FieldName::new(table_id, c), AlgebraicType::I8)); let head1 = Header::new( table_id, "t1".into(), - cols.to_vec(), + columns.to_vec(), vec![ // Index a (a.into(), Constraints::primary_key()), @@ -2191,34 +2313,29 @@ mod tests { ], ); - (head1, fields, vals) + (head1, col_ids, vals) } fn make_field_value<'a>( arena: &'a Arena, - (cmp, field, value): 
(OpCmp, FieldName, &'a AlgebraicValue), - ) -> FieldValue<'a> { - let from_expr = |expr| Box::new(ColumnOp::Field(expr)); + (cmp, col, value): (OpCmp, ColId, &'a AlgebraicValue), + ) -> ColValue<'a> { + let from_expr = |expr| Box::new(ColumnOp::Col(expr)); let op = ColumnOp::Cmp { op: OpQuery::Cmp(cmp), - lhs: from_expr(FieldExpr::Name(field)), - rhs: from_expr(FieldExpr::Value(value.clone())), + lhs: from_expr(ColExpr::Col(col)), + rhs: from_expr(ColExpr::Value(value.clone())), }; let parent = arena.alloc(op); - FieldValue::new(parent, cmp, field, value) + ColValue::new(parent, col, cmp, value) } - fn scan_eq<'a>(arena: &'a Arena, field: FieldName, val: &'a AlgebraicValue) -> IndexColumnOp<'a> { - scan(arena, OpCmp::Eq, field, val) + fn scan_eq<'a>(arena: &'a Arena, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> { + scan(arena, OpCmp::Eq, col, val) } - fn scan<'a>( - arena: &'a Arena, - cmp: OpCmp, - field: FieldName, - val: &'a AlgebraicValue, - ) -> IndexColumnOp<'a> { - IndexColumnOp::Scan(make_field_value(arena, (cmp, field, val)).parent) + fn scan<'a>(arena: &'a Arena, cmp: OpCmp, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> { + IndexColumnOp::Scan(make_field_value(arena, (cmp, col, val)).parent) } #[test] @@ -2232,7 +2349,7 @@ mod tests { let fields = fields .iter() .copied() - .map(|(col, val): (FieldName, _)| make_field_value(&arena, (OpCmp::Eq, col, val)).parent) + .map(|(col, val): (ColId, _)| make_field_value(&arena, (OpCmp::Eq, col, val)).parent) .collect::>(); select_best_index(&mut <_>::default(), &head1, &fields) }; @@ -2248,19 +2365,19 @@ mod tests { assert_eq!( select_best_index(&[(col_a, &val_a)]), - [idx_eq(col_a.col.into(), val_a.clone())].into(), + [idx_eq(col_a.into(), val_a.clone())].into(), ); assert_eq!( select_best_index(&[(col_b, &val_b)]), - [idx_eq(col_b.col.into(), val_b.clone())].into(), + [idx_eq(col_b.into(), val_b.clone())].into(), ); // Check for permutation assert_eq!( select_best_index(&[(col_b, &val_b), 
(col_c, &val_c)]), [idx_eq( - col_list![col_b.col, col_c.col], + col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into() )] .into(), @@ -2269,7 +2386,7 @@ mod tests { assert_eq!( select_best_index(&[(col_c, &val_c), (col_b, &val_b)]), [idx_eq( - col_list![col_b.col, col_c.col], + col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into() )] .into(), @@ -2279,7 +2396,7 @@ mod tests { assert_eq!( select_best_index(&[(col_a, &val_a), (col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]), [idx_eq( - col_list![col_a.col, col_b.col, col_c.col, col_d.col], + col_list![col_a, col_b, col_c, col_d], product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(), )] .into(), @@ -2288,7 +2405,7 @@ mod tests { assert_eq!( select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_d, &val_d), (col_c, &val_c)]), [idx_eq( - col_list![col_a.col, col_b.col, col_c.col, col_d.col], + col_list![col_a, col_b, col_c, col_d], product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(), )] .into() @@ -2298,8 +2415,8 @@ mod tests { assert_eq!( select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_e, &val_e), (col_d, &val_d)]), [ - idx_eq(col_a.col.into(), val_a.clone()), - idx_eq(col_b.col.into(), val_b.clone()), + idx_eq(col_a.into(), val_a.clone()), + idx_eq(col_b.into(), val_b.clone()), scan_eq(&arena, col_d, &val_d), scan_eq(&arena, col_e, &val_e), ] @@ -2309,10 +2426,7 @@ mod tests { assert_eq!( select_best_index(&[(col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]), [ - idx_eq( - col_list![col_b.col, col_c.col], - product![val_b.clone(), val_c.clone()].into(), - ), + idx_eq(col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into(),), scan_eq(&arena, col_d, &val_d), ] .into() @@ -2323,12 +2437,12 @@ mod tests { fn best_index_range() { let arena = Arena::new(); - let (head1, fields, vals) = setup_best_index(); - let [col_a, col_b, col_c, col_d, _] = fields; + let (head1, cols, vals) = setup_best_index(); + let 
[col_a, col_b, col_c, col_d, _] = cols; let [val_a, val_b, val_c, val_d, _] = vals; - let select_best_index = |fields: &[_]| { - let fields = fields + let select_best_index = |cols: &[_]| { + let fields = cols .iter() .map(|x| make_field_value(&arena, *x).parent) .collect::>(); @@ -2336,8 +2450,8 @@ mod tests { }; let col_list_arena = Arena::new(); - let idx = |cmp, cols: &[FieldName], val: &AlgebraicValue| { - let columns = cols.iter().map(|c| c.col).collect::().build().unwrap(); + let idx = |cmp, cols: &[ColId], val: &AlgebraicValue| { + let columns = cols.iter().copied().collect::().build().unwrap(); let columns = col_list_arena.alloc(columns); make_index_arg(cmp, columns, val.clone()) }; @@ -2483,7 +2597,7 @@ mod tests { .with_access(StAccess::Public) .with_type(StTableType::System); // hah! - let crud = CrudExpr::CreateTable { table }; + let crud = CrudExpr::CreateTable { table: Box::new(table) }; assert_owner_required(crud); } @@ -2520,7 +2634,9 @@ mod tests { let q = QueryExpr::new(lhs_source.clone()) .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false) .with_project( - &[0, 1].map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into()))), + [0, 1] + .map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into()))) + .into(), Some(TableId(0)), ) .unwrap(); @@ -2579,14 +2695,14 @@ mod tests { TableId(0), TableDef::new( "lhs".into(), - ProductType::from_iter([AlgebraicType::I32, AlgebraicType::String]).into(), + ProductType::from([AlgebraicType::I32, AlgebraicType::String]).into(), ), ); let rhs = TableSchema::from_def( TableId(1), TableDef::new( "rhs".into(), - ProductType::from_iter([AlgebraicType::I32, AlgebraicType::I64]).into(), + ProductType::from([AlgebraicType::I32, AlgebraicType::I64]).into(), ), ); @@ -2596,7 +2712,9 @@ mod tests { let q = QueryExpr::new(lhs_source.clone()) .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false) .with_project( - &[0, 1].map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into()))), + [0, 1] + 
.map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into()))) + .into(), Some(TableId(1)), ) .unwrap(); diff --git a/crates/vm/src/iterators.rs b/crates/vm/src/iterators.rs index e7140e7461..9b4ae036fe 100644 --- a/crates/vm/src/iterators.rs +++ b/crates/vm/src/iterators.rs @@ -1,34 +1,21 @@ -use crate::errors::ErrorVm; use crate::rel_ops::RelOps; use crate::relation::RelValue; -use spacetimedb_sats::relation::{Header, RowCount}; -use std::sync::Arc; -/// Turns an iterator over `ProductValue`s into a `RelOps`. +/// Turns an iterator over [`RelValue<'_>`]s into a `RelOps`. #[derive(Debug)] pub struct RelIter { - pub head: Arc
, - pub row_count: RowCount, pub iter: I, } impl RelIter { - pub fn new(head: Arc
, row_count: RowCount, iter: impl IntoIterator) -> Self { + pub fn new(iter: impl IntoIterator) -> Self { let iter = iter.into_iter(); - Self { head, row_count, iter } + Self { iter } } } impl<'a, I: Iterator>> RelOps<'a> for RelIter { - fn head(&self) -> &Arc
{ - &self.head - } - - fn row_count(&self) -> RowCount { - self.row_count - } - - fn next(&mut self) -> Result>, ErrorVm> { - Ok(self.iter.next()) + fn next(&mut self) -> Option> { + self.iter.next() } } diff --git a/crates/vm/src/program.rs b/crates/vm/src/program.rs index b9665088d7..d5166eb8ed 100644 --- a/crates/vm/src/program.rs +++ b/crates/vm/src/program.rs @@ -3,10 +3,7 @@ //! It carries an [EnvDb] with the functions, idents, types. use crate::errors::ErrorVm; -use crate::eval::{build_query, build_source_expr_query}; use crate::expr::{Code, CrudExpr, SourceSet}; -use crate::rel_ops::RelOps; -use crate::relation::MemTable; use spacetimedb_sats::ProductValue; /// A trait to allow split the execution of `programs` to allow executing @@ -22,43 +19,3 @@ pub trait ProgramVm { } pub type Sources<'a, const N: usize> = &'a mut SourceSet, N>; - -/// A default program that run in-memory without a database -pub struct Program; - -impl ProgramVm for Program { - fn eval_query(&mut self, query: CrudExpr, sources: Sources<'_, N>) -> Result { - match query { - CrudExpr::Query(query) => { - let result = build_source_expr_query(sources, &query.source); - let result = build_query(result, &query.query, sources)?; - - let head = result.head().clone(); - let rows: Vec<_> = result.collect_vec(|row| row.into_product_value())?; - - Ok(Code::Table(MemTable::new(head, query.source.table_access(), rows))) - } - CrudExpr::Insert { .. } => { - todo!() - } - CrudExpr::Update { .. } => { - todo!() - } - CrudExpr::Delete { .. } => { - todo!() - } - CrudExpr::CreateTable { .. } => { - todo!() - } - CrudExpr::Drop { .. } => { - todo!() - } - CrudExpr::SetVar { .. } => { - todo!() - } - CrudExpr::ReadVar { .. 
} => { - todo!() - } - } - } -} diff --git a/crates/vm/src/rel_ops.rs b/crates/vm/src/rel_ops.rs index 5f8cee1da0..c5209646c9 100644 --- a/crates/vm/src/rel_ops.rs +++ b/crates/vm/src/rel_ops.rs @@ -1,18 +1,12 @@ -use crate::errors::ErrorVm; use crate::relation::RelValue; use spacetimedb_data_structures::map::HashMap; -use spacetimedb_sats::relation::{FieldExpr, Header, RowCount}; +use spacetimedb_sats::relation::ColExpr; use spacetimedb_sats::AlgebraicValue; -use std::sync::Arc; /// A trait for dealing with fallible iterators for the database. pub trait RelOps<'a> { - fn head(&self) -> &Arc
; - fn row_count(&self) -> RowCount { - RowCount::unknown() - } /// Advances the `iterator` and returns the next [RelValue]. - fn next(&mut self) -> Result>, ErrorVm>; + fn next(&mut self) -> Option>; /// Creates an `Iterator` which uses a closure to determine if a [RelValueRef] should be yielded. /// @@ -25,7 +19,7 @@ pub trait RelOps<'a> { #[inline] fn select

(self, predicate: P) -> Select where - P: FnMut(&RelValue<'_>) -> Result, + P: FnMut(&RelValue<'_>) -> bool, Self: Sized, { Select::new(self, predicate) @@ -41,13 +35,12 @@ pub trait RelOps<'a> { /// /// It is the equivalent of a `SELECT` clause on SQL. #[inline] - fn project<'b, P>(self, after_head: &'b Arc

, cols: &'b [FieldExpr], extractor: P) -> Project<'b, Self, P> + fn project<'b, P>(self, cols: &'b [ColExpr], extractor: P) -> Project<'b, Self, P> where - P: for<'c> FnMut(&[FieldExpr], RelValue<'c>) -> Result, ErrorVm>, + P: for<'c> FnMut(&[ColExpr], RelValue<'c>) -> RelValue<'c>, Self: Sized, { - let count = self.row_count(); - Project::new(self, count, after_head, cols, extractor) + Project::new(self, cols, extractor) } /// Intersection between the left and the right, both (non-sorted) `iterators`. @@ -64,7 +57,6 @@ pub trait RelOps<'a> { fn join_inner( self, with: Rhs, - head: Arc
, key_lhs: KeyLhs, key_rhs: KeyRhs, predicate: Pred, @@ -78,37 +70,25 @@ pub trait RelOps<'a> { KeyRhs: FnMut(&RelValue<'a>) -> AlgebraicValue, Rhs: RelOps<'a>, { - JoinInner::new(head, self, with, key_lhs, key_rhs, predicate, project) + JoinInner::new(self, with, key_lhs, key_rhs, predicate, project) } /// Collect all the rows in this relation into a `Vec` given a function `RelValue<'a> -> T`. #[inline] - fn collect_vec(mut self, mut convert: impl FnMut(RelValue<'a>) -> T) -> Result, ErrorVm> + fn collect_vec(mut self, mut convert: impl FnMut(RelValue<'a>) -> T) -> Vec where Self: Sized, { - let count = self.row_count(); - let estimate = count.max.unwrap_or(count.min); - let mut result = Vec::with_capacity(estimate); - - while let Some(row) = self.next()? { + let mut result = Vec::new(); + while let Some(row) = self.next() { result.push(convert(row)); } - - Ok(result) + result } } impl<'a, I: RelOps<'a> + ?Sized> RelOps<'a> for Box { - fn head(&self) -> &Arc
{ - (**self).head() - } - - fn row_count(&self) -> RowCount { - (**self).row_count() - } - - fn next(&mut self) -> Result>, ErrorVm> { + fn next(&mut self) -> Option> { (**self).next() } } @@ -117,23 +97,11 @@ impl<'a, I: RelOps<'a> + ?Sized> RelOps<'a> for Box { /// /// Used to compile queries with unsatisfiable bounds, like `WHERE x < 5 AND x > 5`. #[derive(Clone, Debug)] -pub struct EmptyRelOps { - head: Arc
, -} - -impl EmptyRelOps { - pub fn new(head: Arc
) -> Self { - Self { head } - } -} +pub struct EmptyRelOps; impl<'a> RelOps<'a> for EmptyRelOps { - fn head(&self) -> &Arc
{ - &self.head - } - - fn next(&mut self) -> Result>, ErrorVm> { - Ok(None) + fn next(&mut self) -> Option> { + None } } @@ -152,75 +120,44 @@ impl Select { impl<'a, I, P> RelOps<'a> for Select where I: RelOps<'a>, - P: FnMut(&RelValue<'a>) -> Result, + P: FnMut(&RelValue<'a>) -> bool, { - fn head(&self) -> &Arc
{ - self.iter.head() - } - - fn next(&mut self) -> Result>, ErrorVm> { + fn next(&mut self) -> Option> { let filter = &mut self.predicate; - while let Some(v) = self.iter.next()? { - if filter(&v)? { - return Ok(Some(v)); + while let Some(v) = self.iter.next() { + if filter(&v) { + return Some(v); } } - Ok(None) + None } } #[derive(Clone, Debug)] pub struct Project<'a, I, P> { - pub(crate) head: &'a Arc
, - pub(crate) count: RowCount, - pub(crate) cols: &'a [FieldExpr], + pub(crate) cols: &'a [ColExpr], pub(crate) iter: I, pub(crate) extractor: P, } impl<'a, I, P> Project<'a, I, P> { - pub fn new( - iter: I, - count: RowCount, - head: &'a Arc
, - cols: &'a [FieldExpr], - extractor: P, - ) -> Project<'a, I, P> { - Project { - iter, - count, - cols, - extractor, - head, - } + pub fn new(iter: I, cols: &'a [ColExpr], extractor: P) -> Project<'a, I, P> { + Project { iter, cols, extractor } } } impl<'a, I, P> RelOps<'a> for Project<'_, I, P> where I: RelOps<'a>, - P: FnMut(&[FieldExpr], RelValue<'a>) -> Result, ErrorVm>, + P: FnMut(&[ColExpr], RelValue<'a>) -> RelValue<'a>, { - fn head(&self) -> &Arc
{ - self.head - } - - fn row_count(&self) -> RowCount { - self.count - } - - fn next(&mut self) -> Result>, ErrorVm> { - let extract = &mut self.extractor; - if let Some(v) = self.iter.next()? { - return Ok(Some(extract(self.cols, v)?)); - } - Ok(None) + fn next(&mut self) -> Option> { + self.iter.next().map(|v| (self.extractor)(self.cols, v)) } } #[derive(Clone, Debug)] pub struct JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { - pub(crate) head: Arc
, pub(crate) lhs: Lhs, pub(crate) rhs: Rhs, pub(crate) key_lhs: KeyLhs, @@ -233,17 +170,8 @@ pub struct JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { } impl<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { - pub fn new( - head: Arc
, - lhs: Lhs, - rhs: Rhs, - key_lhs: KeyLhs, - key_rhs: KeyRhs, - predicate: Pred, - projection: Proj, - ) -> Self { + pub fn new(lhs: Lhs, rhs: Rhs, key_lhs: KeyLhs, key_rhs: KeyRhs, predicate: Pred, projection: Proj) -> Self { Self { - head, map: HashMap::new(), lhs, rhs, @@ -266,15 +194,11 @@ where Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> bool, Proj: FnMut(RelValue<'a>, RelValue<'a>) -> RelValue<'a>, { - fn head(&self) -> &Arc
{ - &self.head - } - - fn next(&mut self) -> Result>, ErrorVm> { + fn next(&mut self) -> Option> { // Consume `Rhs`, building a map `KeyRhs => Rhs`. if !self.filled_rhs { - self.map = HashMap::with_capacity(self.rhs.row_count().min); - while let Some(row_rhs) = self.rhs.next()? { + self.map = HashMap::new(); + while let Some(row_rhs) = self.rhs.next() { let key_rhs = (self.key_rhs)(&row_rhs); self.map.entry(key_rhs).or_default().push(row_rhs); } @@ -285,10 +209,7 @@ where // Consume a row in `Lhs` and project to `KeyLhs`. let lhs = match &self.left { Some(left) => left, - None => match self.lhs.next()? { - Some(x) => self.left.insert(x), - None => return Ok(None), - }, + None => self.left.insert(self.lhs.next()?), }; let k = (self.key_lhs)(lhs); @@ -297,7 +218,7 @@ where if let Some(rvv) = self.map.get_mut(&k) { if let Some(rhs) = rvv.pop() { if (self.predicate)(lhs, &rhs) { - return Ok(Some((self.projection)(lhs.clone(), rhs))); + return Some((self.projection)(lhs.clone(), rhs)); } } } diff --git a/crates/vm/src/relation.rs b/crates/vm/src/relation.rs index ee417708b5..db8d5626e8 100644 --- a/crates/vm/src/relation.rs +++ b/crates/vm/src/relation.rs @@ -1,9 +1,8 @@ use core::hash::{Hash, Hasher}; use spacetimedb_sats::bsatn::ser::BsatnError; use spacetimedb_sats::db::auth::StAccess; -use spacetimedb_sats::db::error::RelationError; use spacetimedb_sats::product_value::ProductValue; -use spacetimedb_sats::relation::{FieldExpr, FieldExprRef, FieldName, Header, Relation, RowCount}; +use spacetimedb_sats::relation::{ColExpr, ColExprRef, Header}; use spacetimedb_sats::{bsatn, impl_serialize, AlgebraicValue}; use spacetimedb_table::read_column::ReadColumn; use spacetimedb_table::table::RowRef; @@ -14,7 +13,7 @@ use std::sync::Arc; /// a reference to an inserted row, /// or an ephemeral row constructed during query execution. /// -/// A `RelValue` is the type generated/consumed by a [Relation] operator. +/// A `RelValue` is the type generated/consumed by queries. 
#[derive(Debug, Clone)] pub enum RelValue<'a> { /// A reference to a row in a table. @@ -116,34 +115,22 @@ impl<'a> RelValue<'a> { } } - pub fn get<'b>( - &'a self, - col: FieldExprRef<'a>, - header: &'b Header, - ) -> Result, RelationError> { - let val = match col { - FieldExprRef::Name(col) => { - let pos = header.column_pos_or_err(col)?.idx(); - self.read_column(pos) - .ok_or_else(|| RelationError::FieldNotFoundAtPos(pos, col))? - } - FieldExprRef::Value(x) => Cow::Borrowed(x), - }; - - Ok(val) - } - - pub fn project(&self, cols: &[FieldExprRef<'_>], header: &'a Header) -> Result { - let mut elements = Vec::with_capacity(cols.len()); - for col in cols { - elements.push(self.get(*col, header)?.into_owned()); + /// Returns a column either at the index specified in `col`, + /// or the column is the value that `col` holds. + /// + /// Panics if, for `ColExprRef::Col(col)`, the `col` is out of bounds of `self`. + pub fn get(&'a self, col: ColExprRef<'a>) -> Cow<'a, AlgebraicValue> { + match col { + ColExprRef::Col(col) => self.read_column(col.idx()).unwrap(), + ColExprRef::Value(x) => Cow::Borrowed(x), } - Ok(elements.into()) } /// Reads or takes the column at `col`. /// Calling this method consumes the column at `col` for a `RelValue::Projection`, /// so it should not be called again for the same input. + /// + /// Panics if `col` is out of bounds of `self`. pub fn read_or_take_column(&mut self, col: usize) -> Option { match self { Self::Row(row_ref) => AlgebraicValue::read_column(*row_ref, col).ok(), @@ -152,20 +139,17 @@ impl<'a> RelValue<'a> { } } - pub fn project_owned(mut self, cols: &[FieldExpr], header: &Header) -> Result { - let mut elements = Vec::with_capacity(cols.len()); - for col in cols { - let val = match col { - FieldExpr::Name(col) => { - let pos = header.column_pos_or_err(*col)?.idx(); - self.read_or_take_column(pos) - .ok_or_else(|| RelationError::FieldNotFoundAtPos(pos, *col))? 
- } - FieldExpr::Value(x) => x.clone(), - }; - elements.push(val); - } - Ok(elements.into()) + /// Turns `cols` into a product + /// where a value in `cols` is taken directly from it and indices are taken from `self`. + /// + /// Panics on an index that is out of bounds of `self`. + pub fn project_owned(mut self, cols: &[ColExpr]) -> ProductValue { + cols.iter() + .map(|col| match col { + ColExpr::Col(col) => self.read_or_take_column(col.idx()).unwrap(), + ColExpr::Value(x) => x.clone(), + }) + .collect() } /// BSATN-encode the row referred to by `self` into `buf`, @@ -214,18 +198,4 @@ impl MemTable { table_access: StAccess::Public, } } - - pub fn get_field_pos(&self, pos: usize) -> Option<&FieldName> { - self.head.fields.get(pos).map(|x| &x.field) - } -} - -impl Relation for MemTable { - fn head(&self) -> &Arc
{ - &self.head - } - - fn row_count(&self) -> RowCount { - RowCount::exact(self.data.len()) - } }