From df5b78a2ffe823ce1f0fca4be0a784233512e911 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Fri, 11 Oct 2024 10:54:07 -0700 Subject: [PATCH] refactor: Add TableId to relvar type (#1803) --- Cargo.lock | 30 +- Cargo.toml | 4 +- crates/core/Cargo.toml | 2 +- crates/core/src/error.rs | 2 +- crates/core/src/sql/ast.rs | 6 +- .../subscription/module_subscription_actor.rs | 11 +- crates/{planner => expr}/Cargo.toml | 4 +- crates/{planner => expr}/LICENSE | 0 .../src/logical/bind.rs => expr/src/check.rs} | 63 +- .../src/logical => expr/src}/errors.rs | 2 +- .../{planner/src/logical => expr/src}/expr.rs | 5 +- .../src/logical/mod.rs => expr/src/lib.rs} | 97 ++- .../logical/stmt.rs => expr/src/statement.rs} | 109 +-- crates/expr/src/ty.rs | 639 ++++++++++++++++++ crates/planner/src/lib.rs | 1 - crates/planner/src/logical/ty.rs | 340 ---------- crates/sql-parser/src/ast/mod.rs | 34 +- crates/sql-parser/src/parser/mod.rs | 6 +- 18 files changed, 823 insertions(+), 532 deletions(-) rename crates/{planner => expr}/Cargo.toml (74%) rename crates/{planner => expr}/LICENSE (100%) rename crates/{planner/src/logical/bind.rs => expr/src/check.rs} (85%) rename crates/{planner/src/logical => expr/src}/errors.rs (99%) rename crates/{planner/src/logical => expr/src}/expr.rs (96%) rename crates/{planner/src/logical/mod.rs => expr/src/lib.rs} (82%) rename crates/{planner/src/logical/stmt.rs => expr/src/statement.rs} (77%) create mode 100644 crates/expr/src/ty.rs delete mode 100644 crates/planner/src/lib.rs delete mode 100644 crates/planner/src/logical/ty.rs diff --git a/Cargo.lock b/Cargo.lock index 1a8299ab73..d63d715467 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4524,10 +4524,10 @@ dependencies = [ "spacetimedb-commitlog", "spacetimedb-data-structures", "spacetimedb-durability", + "spacetimedb-expr", "spacetimedb-lib", "spacetimedb-metrics", "spacetimedb-primitives", - "spacetimedb-query-planner", "spacetimedb-sats", "spacetimedb-schema", "spacetimedb-snapshot", @@ -4579,6 +4579,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "spacetimedb-expr" +version = "1.0.0-rc1" +dependencies = [ + "derive_more", + "spacetimedb-lib", + "spacetimedb-primitives", + "spacetimedb-sats", + "spacetimedb-schema", + "spacetimedb-sql-parser", + "string-interner", + "thiserror", +] + [[package]] name = "spacetimedb-fs-utils" version = "1.0.0-rc1" @@ -4637,20 +4651,6 @@ dependencies = [ "proptest", ] -[[package]] -name = "spacetimedb-query-planner" -version = "1.0.0-rc1" -dependencies = [ - "derive_more", - "spacetimedb-lib", - "spacetimedb-primitives", - "spacetimedb-sats", - "spacetimedb-schema", - "spacetimedb-sql-parser", - "string-interner", - "thiserror", -] - [[package]] name = "spacetimedb-quickstart-module" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 55aa4a2ebc..2de5800e6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,9 @@ members = [ "crates/core", "crates/data-structures", "crates/durability", + "crates/expr", "crates/fs-utils", "crates/lib", - "crates/planner", "crates/metrics", "crates/primitives", "crates/sats", @@ -96,10 +96,10 @@ spacetimedb-commitlog = { path = "crates/commitlog", version = "1.0.0-rc1" } spacetimedb-core = { path = "crates/core", version = "1.0.0-rc1" } spacetimedb-data-structures = { path = "crates/data-structures", version = "1.0.0-rc1" } spacetimedb-durability = { path = "crates/durability", version = "1.0.0-rc1" } +spacetimedb-expr = { path = "crates/expr", version = "1.0.0-rc1" } spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1.0.0-rc1" } spacetimedb-metrics = { path = "crates/metrics", version = "1.0.0-rc1" } spacetimedb-primitives = { path = "crates/primitives", version = "1.0.0-rc1" } -spacetimedb-query-planner = { path = "crates/planner", version = "1.0.0-rc1" } spacetimedb-sats = { path = "crates/sats", version = "1.0.0-rc1" } spacetimedb-schema = { path = "crates/schema", version = "1.0.0-rc1" } spacetimedb-standalone = { path = "crates/standalone", version = "1.0.0-rc1" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index d7cfdce5ea..a124de55d6 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -27,7 +27,7 @@ spacetimedb-schema.workspace = true spacetimedb-table.workspace = true spacetimedb-vm.workspace = true spacetimedb-snapshot.workspace = true -spacetimedb-query-planner.workspace = true +spacetimedb-expr.workspace = true anyhow = { workspace = true, features = ["backtrace"] } arrayvec.workspace = true diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 8721bed5c8..bc5661ef89 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::sync::{MutexGuard, PoisonError}; use hex::FromHexError; -use spacetimedb_query_planner::logical::errors::TypingError; +use spacetimedb_expr::errors::TypingError; use spacetimedb_sats::AlgebraicType; use spacetimedb_schema::error::ValidationErrors; use spacetimedb_snapshot::SnapshotError; diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 8f366ff10a..9ad4963986 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -1,13 +1,13 @@ use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::{DBError, PlanError}; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; +use spacetimedb_expr::check::SchemaView; +use spacetimedb_expr::statement::parse_and_type_sql; use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::db::error::RelationError; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::relation::{ColExpr, FieldName}; use spacetimedb_primitives::ColId; -use spacetimedb_query_planner::logical::bind::SchemaView; -use spacetimedb_query_planner::logical::stmt::parse_and_type_sql; use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; use spacetimedb_schema::schema::{ColumnSchema, TableSchema}; use spacetimedb_vm::errors::ErrorVm; @@ -480,7 +480,7 @@ pub struct SchemaViewer<'a, T> { } impl SchemaView for SchemaViewer<'_, T> { - fn schema(&self, name: &str, _: bool) -> Option> { + fn schema(&self, name: &str) -> Option> { let name = name.to_owned().into_boxed_str(); let schema = self .tx diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 5bf97bd8e1..f77272b92f 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -15,9 +15,10 @@ use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use parking_lot::RwLock; use spacetimedb_client_api_messages::websocket::FormatSwitch; +use spacetimedb_expr::check::parse_and_type_sub; +use spacetimedb_expr::ty::TyCtx; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Identity; -use spacetimedb_query_planner::logical::bind::parse_and_type_sub; use spacetimedb_vm::errors::ErrorVm; use spacetimedb_vm::expr::AuthAccess; use std::time::Duration; @@ -88,7 +89,11 @@ impl ModuleSubscriptions { } else { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. - parse_and_type_sub(sql, &SchemaViewer::new(&self.relational_db, &*tx, &auth))?; + parse_and_type_sub( + &mut TyCtx::default(), + sql, + &SchemaViewer::new(&self.relational_db, &*tx, &auth), + )?; let mut compiled = compile_read_only_query(&self.relational_db, &auth, &tx, sql)?; // Note that no error path is needed here. @@ -249,9 +254,9 @@ mod tests { use crate::error::DBError; use crate::execution_context::ExecutionContext; use spacetimedb_client_api_messages::websocket::Subscribe; + use spacetimedb_expr::errors::{TypingError, Unresolved}; use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::{error::ResultTest, AlgebraicType, Identity}; - use spacetimedb_query_planner::logical::errors::{TypingError, Unresolved}; use spacetimedb_sats::product; use std::time::Instant; use std::{sync::Arc, time::Duration}; diff --git a/crates/planner/Cargo.toml b/crates/expr/Cargo.toml similarity index 74% rename from crates/planner/Cargo.toml rename to crates/expr/Cargo.toml index 745055d516..43e9cec021 100644 --- a/crates/planner/Cargo.toml +++ b/crates/expr/Cargo.toml @@ -1,8 +1,10 @@ [package] -name = "spacetimedb-query-planner" +name = "spacetimedb-expr" version.workspace = true edition.workspace = true +rust-version.workspace = true license-file = "LICENSE" +description = "The logical expression representation for the SpacetimeDB query engine" [dependencies] derive_more.workspace = true diff --git a/crates/planner/LICENSE b/crates/expr/LICENSE similarity index 100% rename from crates/planner/LICENSE rename to crates/expr/LICENSE diff --git a/crates/planner/src/logical/bind.rs b/crates/expr/src/check.rs similarity index 85% rename from crates/planner/src/logical/bind.rs rename to crates/expr/src/check.rs index 4336e0f41a..72d2e382a9 100644 --- a/crates/planner/src/logical/bind.rs +++ b/crates/expr/src/check.rs @@ -5,16 +5,18 @@ use spacetimedb_sql_parser::{ ast::{ self, sub::{SqlAst, SqlSelect}, - SqlFrom, + SqlFrom, SqlIdent, SqlJoin, }, parser::sub::parse_subscription, }; +use crate::ty::TyId; + use super::{ assert_eq_types, errors::{DuplicateName, TypingError, Unresolved, Unsupported}, expr::{Expr, Let, RelExpr}, - ty::{Symbol, TyCtx, TyEnv, TyId, Type}, + ty::{Symbol, TyCtx, TyEnv}, type_expr, type_proj, type_select, }; @@ -22,7 +24,7 @@ use super::{ pub type TypingResult = core::result::Result; pub trait SchemaView { - fn schema(&self, name: &str, case_sensitive: bool) -> Option>; + fn schema(&self, name: &str) -> Option>; } pub trait TypeChecker { @@ -40,12 +42,12 @@ pub trait TypeChecker { ) -> TypingResult<(RelExpr, Option)> { match from { SqlFrom::Expr(expr, None) => Self::type_rel(ctx, expr, tx), - SqlFrom::Expr(expr, Some(alias)) => { + SqlFrom::Expr(expr, Some(SqlIdent(alias))) => { let (expr, _) = Self::type_rel(ctx, expr, tx)?; - let symbol = ctx.gen_symbol(alias.name); + let symbol = ctx.gen_symbol(alias); Ok((expr, Some(symbol))) } - SqlFrom::Join(r, alias, joins) => { + SqlFrom::Join(r, SqlIdent(alias), joins) => { // The type environment with which to type the join expressions let mut env = TyEnv::default(); // The lowered inputs to the join operator @@ -57,33 +59,37 @@ pub trait TypeChecker { let input = Self::type_rel(ctx, r, tx)?.0; let ty = input.ty_id(); - let name = ctx.gen_symbol(alias.name); + let name = ctx.gen_symbol(alias); env.add(name, ty); inputs.push(input); types.push((name, ty)); - for join in joins { - let input = Self::type_rel(ctx, join.expr, tx)?.0; + for SqlJoin { + expr, + alias: SqlIdent(alias), + on, + } in joins + { + let input = Self::type_rel(ctx, expr, tx)?.0; let ty = input.ty_id(); - let name = ctx.gen_symbol(&join.alias.name); + let name = ctx.gen_symbol(&alias); // New join variable is now in scope if env.add(name, ty).is_some() { - return Err(DuplicateName(join.alias.name).into()); + return Err(DuplicateName(alias.into_string()).into()); } inputs.push(input); types.push((name, ty)); // Type check join expression with current type environment - if let Some(on) = join.on { + if let Some(on) = on { exprs.push(type_expr(ctx, &env, on, Some(TyId::BOOL))?); } } - let ty = Type::Row(types.clone().into_boxed_slice()); - let ty = ctx.add(ty); + let ty = ctx.add_row_type(types.clone()); let input = RelExpr::Join(inputs.into(), ty); let vars = types .into_iter() @@ -101,21 +107,19 @@ pub trait TypeChecker { tx: &impl SchemaView, ) -> TypingResult<(RelExpr, Option)> { match expr { - ast::RelExpr::Var(var) => { + ast::RelExpr::Var(SqlIdent(var)) => { let schema = tx - .schema(&var.name, var.case_sensitive) - .ok_or_else(|| Unresolved::table(&var.name)) + .schema(&var) + .ok_or_else(|| Unresolved::table(&var)) .map_err(TypingError::from)?; let mut types = Vec::new(); for ColumnSchema { col_name, col_type, .. } in schema.columns() { - let ty = Type::Alg(col_type.clone()); - let id = ctx.add(ty); + let id = ctx.add_algebraic_type(col_type); let name = ctx.gen_symbol(col_name); types.push((name, id)); } - let ty = Type::Var(types.into_boxed_slice()); - let id = ctx.add(ty); - let symbol = ctx.gen_symbol(var.name); + let id = ctx.add_var_type(schema.table_id, types); + let symbol = ctx.gen_symbol(var); Ok((RelExpr::RelVar(schema, id), Some(symbol))) } ast::RelExpr::Ast(ast) => Ok((Self::type_ast(ctx, *ast, tx)?, None)), @@ -170,10 +174,9 @@ impl TypeChecker for SubChecker { } /// Parse and type check a subscription query -pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult { - let mut ctx = TyCtx::default(); - let expr = SubChecker::type_ast(&mut ctx, parse_subscription(sql)?, tx)?; - expect_table_type(&ctx, expr) +pub fn parse_and_type_sub(ctx: &mut TyCtx, sql: &str, tx: &impl SchemaView) -> TypingResult { + let expr = SubChecker::type_ast(ctx, parse_subscription(sql)?, tx)?; + expect_table_type(ctx, expr) } /// Returns an error if the input type is not a table type or relvar @@ -192,6 +195,8 @@ mod tests { }; use std::sync::Arc; + use crate::ty::TyCtx; + use super::{parse_and_type_sub, SchemaView}; fn module_def() -> ModuleDef { @@ -222,7 +227,7 @@ mod tests { struct SchemaViewer(ModuleDef); impl SchemaView for SchemaViewer { - fn schema(&self, name: &str, _: bool) -> Option> { + fn schema(&self, name: &str) -> Option> { self.0.table(name).map(|def| { Arc::new(TableSchema::from_module_def( &self.0, @@ -253,7 +258,7 @@ mod tests { "select * from (select t.* from t join (select u32 as a from s) s on t.u32 = s.a)", "select * from (select * from t union all select * from t)", ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(&mut TyCtx::default(), sql, &tx); assert!(result.is_ok()); } } @@ -294,7 +299,7 @@ mod tests { // Union arguments are of different types "select * from (select * from t union all select * from s)", ] { - let result = parse_and_type_sub(sql, &tx); + let result = parse_and_type_sub(&mut TyCtx::default(), sql, &tx); assert!(result.is_err()); } } diff --git a/crates/planner/src/logical/errors.rs b/crates/expr/src/errors.rs similarity index 99% rename from crates/planner/src/logical/errors.rs rename to crates/expr/src/errors.rs index aeb570343a..44ceb4eb8c 100644 --- a/crates/planner/src/logical/errors.rs +++ b/crates/expr/src/errors.rs @@ -2,7 +2,7 @@ use spacetimedb_sql_parser::{ast::BinOp, parser::errors::SqlParseError}; use thiserror::Error; use super::{ - stmt::InvalidVar, + statement::InvalidVar, ty::{InvalidTypeId, TypeWithCtx}, }; diff --git a/crates/planner/src/logical/expr.rs b/crates/expr/src/expr.rs similarity index 96% rename from crates/planner/src/logical/expr.rs rename to crates/expr/src/expr.rs index 1564b4041d..42b7feb68a 100644 --- a/crates/planner/src/logical/expr.rs +++ b/crates/expr/src/expr.rs @@ -112,9 +112,8 @@ impl Expr { } /// Returns a string literal - pub fn str(v: String) -> Self { - let s = v.into_boxed_str(); - Self::Lit(AlgebraicValue::String(s), TyId::STR) + pub fn str(v: Box) -> Self { + Self::Lit(AlgebraicValue::String(v), TyId::STR) } /// The type id of this expression diff --git a/crates/planner/src/logical/mod.rs b/crates/expr/src/lib.rs similarity index 82% rename from crates/planner/src/logical/mod.rs rename to crates/expr/src/lib.rs index 5a3c1ecce1..71e8583de1 100644 --- a/crates/planner/src/logical/mod.rs +++ b/crates/expr/src/lib.rs @@ -1,16 +1,16 @@ use std::collections::HashSet; -use bind::TypingResult; +use check::TypingResult; use errors::{DuplicateName, InvalidLiteral, InvalidWildcard, UnexpectedType, Unresolved}; use expr::{Expr, Let, RelExpr}; use spacetimedb_lib::{from_hex_pad, Address, AlgebraicType, AlgebraicValue, Identity}; -use spacetimedb_sql_parser::ast::{self, ProjectElem, ProjectExpr, SqlExpr, SqlLiteral}; +use spacetimedb_sql_parser::ast::{self, ProjectElem, ProjectExpr, SqlExpr, SqlIdent, SqlLiteral}; use ty::{Symbol, TyCtx, TyEnv, TyId, Type, TypeWithCtx}; -pub mod bind; +pub mod check; pub mod errors; pub mod expr; -pub mod stmt; +pub mod statement; pub mod ty; /// Asserts that `$ty` is `$size` bytes in `static_assert_size($ty, $size)`. @@ -71,9 +71,9 @@ pub(crate) fn type_proj( } Ok(input) } - ast::Project::Star(Some(var)) => { + ast::Project::Star(Some(SqlIdent(var))) => { // Get the symbol for this variable - let name = ctx.get_symbol(&var.name).ok_or_else(|| Unresolved::var(&var.name))?; + let name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?; match alias { Some(alias) if alias == name => { @@ -101,7 +101,7 @@ pub(crate) fn type_proj( .ty(ctx)? .expect_relation()? .find(name) - .ok_or_else(|| Unresolved::var(&var.name))?; + .ok_or_else(|| Unresolved::var(&var))?; // Check that * is applied to a row type ctx.try_resolve(ty)? @@ -143,37 +143,37 @@ pub(crate) fn type_proj( for elem in elems { match elem { - ProjectElem(ProjectExpr::Var(field), None) => { - let name = ctx.gen_symbol(&field.name); + ProjectElem(ProjectExpr::Var(SqlIdent(field)), None) => { + let name = ctx.gen_symbol(&field); if !names.insert(name) { - return Err(DuplicateName(field.name).into()); + return Err(DuplicateName(field.into_string()).into()); } - let expr = type_expr(ctx, &tenv, SqlExpr::Var(field), None)?; + let expr = type_expr(ctx, &tenv, SqlExpr::Var(SqlIdent(field)), None)?; field_types.push((name, expr.ty_id())); field_exprs.push((name, expr)); } - ProjectElem(ProjectExpr::Var(field), Some(alias)) => { - let name = ctx.gen_symbol(&alias.name); + ProjectElem(ProjectExpr::Var(field), Some(SqlIdent(alias))) => { + let name = ctx.gen_symbol(&alias); if !names.insert(name) { - return Err(DuplicateName(alias.name).into()); + return Err(DuplicateName(alias.into_string()).into()); } let expr = type_expr(ctx, &tenv, SqlExpr::Var(field), None)?; field_types.push((name, expr.ty_id())); field_exprs.push((name, expr)); } - ProjectElem(ProjectExpr::Field(table, field), None) => { - let name = ctx.gen_symbol(&field.name); + ProjectElem(ProjectExpr::Field(table, SqlIdent(field)), None) => { + let name = ctx.gen_symbol(&field); if !names.insert(name) { - return Err(DuplicateName(field.name).into()); + return Err(DuplicateName(field.into_string()).into()); } - let expr = type_expr(ctx, &tenv, SqlExpr::Field(table, field), None)?; + let expr = type_expr(ctx, &tenv, SqlExpr::Field(table, SqlIdent(field)), None)?; field_types.push((name, expr.ty_id())); field_exprs.push((name, expr)); } - ProjectElem(ProjectExpr::Field(table, field), Some(alias)) => { - let name = ctx.gen_symbol(&alias.name); + ProjectElem(ProjectExpr::Field(table, field), Some(SqlIdent(alias))) => { + let name = ctx.gen_symbol(&alias); if !names.insert(name) { - return Err(DuplicateName(alias.name).into()); + return Err(DuplicateName(alias.into_string()).into()); } let expr = type_expr(ctx, &tenv, SqlExpr::Field(table, field), None)?; field_types.push((name, expr.ty_id())); @@ -184,8 +184,7 @@ pub(crate) fn type_proj( // Column projections produce a new type. // So we must make sure to add it to the typing context. - let ty = Type::Row(field_types.into_boxed_slice()); - let id = ctx.add(ty); + let id = ctx.add_row_type(field_types); Ok(RelExpr::project( input, Let { @@ -203,67 +202,59 @@ pub(crate) fn type_expr(ctx: &TyCtx, vars: &TyEnv, expr: SqlExpr, expected: Opti (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(TyId::BOOL)) => Ok(Expr::bool(v)), (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(id)) => { let expected = ctx.bool(); - let inferred = id.try_with_ctx(ctx)?; + let inferred = ctx.try_resolve(id)?; Err(UnexpectedType::new(&expected, &inferred).into()) } (SqlExpr::Lit(SqlLiteral::Str(v)), None | Some(TyId::STR)) => Ok(Expr::str(v)), (SqlExpr::Lit(SqlLiteral::Str(_)), Some(id)) => { let expected = ctx.str(); - let inferred = id.try_with_ctx(ctx)?; + let inferred = ctx.try_resolve(id)?; Err(UnexpectedType::new(&expected, &inferred).into()) } (SqlExpr::Lit(SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => Err(Unresolved::Literal.into()), (SqlExpr::Lit(SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(id)) => { - let t = id.try_with_ctx(ctx)?; - let v = parse(v, t)?; + let t = ctx.try_resolve(id)?; + let v = parse(v.into_string(), t)?; Ok(Expr::Lit(v, id)) } - (SqlExpr::Var(var), None) => { + (SqlExpr::Var(SqlIdent(var)), None) => { // Is this variable in scope? - let var_name = ctx.get_symbol(&var.name).ok_or_else(|| Unresolved::var(&var.name))?; - let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var.name))?; + let var_name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?; + let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var))?; Ok(Expr::Var(var_name, var_type)) } - (SqlExpr::Var(var), Some(id)) => { + (SqlExpr::Var(SqlIdent(var)), Some(id)) => { // Is this variable in scope? - let var_name = ctx.get_symbol(&var.name).ok_or_else(|| Unresolved::var(&var.name))?; - let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var.name))?; + let var_name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?; + let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var))?; // Is it the correct type? assert_eq_types(ctx, var_type, id)?; Ok(Expr::Var(var_name, var_type)) } - (SqlExpr::Field(table, field), None) => { + (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => { // Is the table variable in scope? - let table_name = ctx - .get_symbol(&table.name) - .ok_or_else(|| Unresolved::var(&table.name))?; - let field_name = ctx - .get_symbol(&field.name) - .ok_or_else(|| Unresolved::var(&field.name))?; - let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table.name))?; + let table_name = ctx.get_symbol(&table).ok_or_else(|| Unresolved::var(&table))?; + let field_name = ctx.get_symbol(&field).ok_or_else(|| Unresolved::var(&field))?; + let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table))?; // Is it a row type, and if so, does it have this field? let (i, field_type) = ctx .try_resolve(table_type)? .expect_relation()? .find(field_name) - .ok_or_else(|| Unresolved::field(&table.name, &field.name))?; + .ok_or_else(|| Unresolved::field(&table, &field))?; Ok(Expr::Field(Box::new(Expr::Var(table_name, table_type)), i, field_type)) } - (SqlExpr::Field(table, field), Some(id)) => { + (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(id)) => { // Is the table variable in scope? - let table_name = ctx - .get_symbol(&table.name) - .ok_or_else(|| Unresolved::var(&table.name))?; - let field_name = ctx - .get_symbol(&field.name) - .ok_or_else(|| Unresolved::var(&field.name))?; - let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table.name))?; + let table_name = ctx.get_symbol(&table).ok_or_else(|| Unresolved::var(&table))?; + let field_name = ctx.get_symbol(&field).ok_or_else(|| Unresolved::var(&field))?; + let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table))?; // Is it a row type, and if so, does it have this field? let (i, field_type) = ctx .try_resolve(table_type)? .expect_relation()? .find(field_name) - .ok_or_else(|| Unresolved::field(&table.name, &field.name))?; + .ok_or_else(|| Unresolved::field(&table, &field))?; // Is the field type correct? assert_eq_types(ctx, field_type, id)?; Ok(Expr::Field(Box::new(Expr::Var(table_name, table_type)), i, field_type)) @@ -280,7 +271,7 @@ pub(crate) fn type_expr(ctx: &TyCtx, vars: &TyEnv, expr: SqlExpr, expected: Opti }, (SqlExpr::Bin(..), Some(id)) => { let expected = ctx.bool(); - let inferred = id.try_with_ctx(ctx)?; + let inferred = ctx.try_resolve(id)?; Err(UnexpectedType::new(&expected, &inferred).into()) } } @@ -289,7 +280,7 @@ pub(crate) fn type_expr(ctx: &TyCtx, vars: &TyEnv, expr: SqlExpr, expected: Opti /// Assert types are structurally equivalent pub(crate) fn assert_eq_types(ctx: &TyCtx, a: TyId, b: TyId) -> TypingResult<()> { if !ctx.eq(a, b)? { - return Err(UnexpectedType::new(&a.try_with_ctx(ctx)?, &b.try_with_ctx(ctx)?).into()); + return Err(UnexpectedType::new(&ctx.try_resolve(a)?, &ctx.try_resolve(b)?).into()); } Ok(()) } diff --git a/crates/planner/src/logical/stmt.rs b/crates/expr/src/statement.rs similarity index 77% rename from crates/planner/src/logical/stmt.rs rename to crates/expr/src/statement.rs index 02a4d565e4..f6142bea04 100644 --- a/crates/planner/src/logical/stmt.rs +++ b/crates/expr/src/statement.rs @@ -12,17 +12,19 @@ use spacetimedb_sql_parser::{ }; use thiserror::Error; +use crate::ty::TyId; + use super::{ assert_eq_types, - bind::{SchemaView, TypeChecker, TypingResult}, + check::{SchemaView, TypeChecker, TypingResult}, errors::{InsertFieldsError, InsertValuesError, TypingError, UnexpectedType, Unresolved, Unsupported}, expr::{Expr, RelExpr}, parse, - ty::{TyCtx, TyEnv, TyId, Type}, + ty::{TyCtx, TyEnv}, type_expr, type_proj, type_select, }; -pub enum Stmt { +pub enum Statement { Select(RelExpr), Insert(TableInsert), Update(TableUpdate), @@ -62,21 +64,21 @@ pub struct ShowVar { /// Type check an INSERT statement pub fn type_insert(ctx: &mut TyCtx, insert: SqlInsert, tx: &impl SchemaView) -> TypingResult { let SqlInsert { - table: SqlIdent { name, case_sensitive }, + table: SqlIdent(table_name), fields, values, } = insert; let schema = tx - .schema(&name, case_sensitive) - .ok_or_else(|| Unresolved::table(&name)) + .schema(&table_name) + .ok_or_else(|| Unresolved::table(&table_name)) .map_err(TypingError::from)?; // Expect n fields let n = schema.columns().len(); if fields.len() != schema.columns().len() { return Err(TypingError::from(InsertFieldsError { - table: name, + table: table_name.into_string(), nfields: fields.len(), ncols: schema.columns().len(), })); @@ -84,7 +86,7 @@ pub fn type_insert(ctx: &mut TyCtx, insert: SqlInsert, tx: &impl SchemaView) -> let mut types = Vec::new(); for ColumnSchema { col_type, .. } in schema.columns() { - let id = ctx.add(Type::Alg(col_type.clone())); + let id = ctx.add_algebraic_type(col_type); types.push(id); } @@ -93,7 +95,7 @@ pub fn type_insert(ctx: &mut TyCtx, insert: SqlInsert, tx: &impl SchemaView) -> // Expect each row to have n values if row.len() != n { return Err(TypingError::from(InsertValuesError { - table: name, + table: table_name.into_string(), values: row.len(), fields: n, })); @@ -105,17 +107,17 @@ pub fn type_insert(ctx: &mut TyCtx, insert: SqlInsert, tx: &impl SchemaView) -> values.push(AlgebraicValue::Bool(v)); } (SqlLiteral::Str(v), TyId::STR) => { - values.push(AlgebraicValue::String(v.into_boxed_str())); + values.push(AlgebraicValue::String(v)); } (SqlLiteral::Bool(_), id) => { - return Err(UnexpectedType::new(&ctx.bool(), &id.try_with_ctx(ctx)?).into()); + return Err(UnexpectedType::new(&ctx.bool(), &ctx.try_resolve(id)?).into()); } (SqlLiteral::Str(_), id) => { - return Err(UnexpectedType::new(&ctx.str(), &id.try_with_ctx(ctx)?).into()); + return Err(UnexpectedType::new(&ctx.str(), &ctx.try_resolve(id)?).into()); } (SqlLiteral::Hex(v), id) | (SqlLiteral::Num(v), id) => { - let ty = id.try_with_ctx(ctx)?; - values.push(parse(v, ty)?); + let ty = ctx.try_resolve(id)?; + values.push(parse(v.into_string(), ty)?); } } } @@ -129,29 +131,27 @@ pub fn type_insert(ctx: &mut TyCtx, insert: SqlInsert, tx: &impl SchemaView) -> /// Type check a DELETE statement pub fn type_delete(ctx: &mut TyCtx, delete: SqlDelete, tx: &impl SchemaView) -> TypingResult { let SqlDelete { - table: SqlIdent { name, case_sensitive }, + table: SqlIdent(table_name), filter, } = delete; let schema = tx - .schema(&name, case_sensitive) - .ok_or_else(|| Unresolved::table(&name)) + .schema(&table_name) + .ok_or_else(|| Unresolved::table(&table_name)) .map_err(TypingError::from)?; - let table_name = ctx.gen_symbol(name); + let table_name = ctx.gen_symbol(table_name); let mut types = Vec::new(); let mut env = TyEnv::default(); for ColumnSchema { col_name, col_type, .. } in schema.columns() { - let ty = Type::Alg(col_type.clone()); - let id = ctx.add(ty); + let id = ctx.add_algebraic_type(col_type); let name = ctx.gen_symbol(col_name); env.add(name, id); types.push((name, id)); } - let ty = Type::Var(types.into_boxed_slice()); - let ty = ctx.add(ty); + let ty = ctx.add_var_type(schema.table_id, types); env.add(table_name, ty); let from = schema; @@ -164,47 +164,46 @@ pub fn type_delete(ctx: &mut TyCtx, delete: SqlDelete, tx: &impl SchemaView) -> /// Type check an UPDATE statement pub fn type_update(ctx: &mut TyCtx, update: SqlUpdate, tx: &impl SchemaView) -> TypingResult { let SqlUpdate { - table, + table: SqlIdent(table_name), assignments, filter, } = update; let schema = tx - .schema(&table.name, table.case_sensitive) - .ok_or_else(|| Unresolved::table(&table.name)) + .schema(&table_name) + .ok_or_else(|| Unresolved::table(&table_name)) .map_err(TypingError::from)?; let mut env = TyEnv::default(); for ColumnSchema { col_name, col_type, .. } in schema.columns() { - let id = ctx.add(Type::Alg(col_type.clone())); + let id = ctx.add_algebraic_type(col_type); let name = ctx.gen_symbol(col_name); env.add(name, id); } let mut values = Vec::new(); - for SqlSet(field, lit) in assignments { + for SqlSet(SqlIdent(field), lit) in assignments { let col_id = schema - .get_column_id_by_name(&field.name) - .ok_or_else(|| Unresolved::field(&table.name, &field.name))?; + .get_column_id_by_name(&field) + .ok_or_else(|| Unresolved::field(&table_name, &field))?; let field_name = ctx - .get_symbol(&field.name) - .ok_or_else(|| Unresolved::field(&table.name, &field.name))?; + .get_symbol(&field) + .ok_or_else(|| Unresolved::field(&table_name, &field))?; let ty = env .find(field_name) - .ok_or_else(|| Unresolved::field(&table.name, &field.name))?; + .ok_or_else(|| Unresolved::field(&table_name, &field))?; match (lit, ty) { (SqlLiteral::Bool(v), TyId::BOOL) => { values.push((col_id, AlgebraicValue::Bool(v))); } (SqlLiteral::Str(v), TyId::STR) => { - values.push((col_id, AlgebraicValue::String(v.into_boxed_str()))); + values.push((col_id, AlgebraicValue::String(v))); } (SqlLiteral::Bool(_), id) => { - return Err(UnexpectedType::new(&ctx.bool(), &id.try_with_ctx(ctx)?).into()); + return Err(UnexpectedType::new(&ctx.bool(), &ctx.try_resolve(id)?).into()); } (SqlLiteral::Str(_), id) => { - return Err(UnexpectedType::new(&ctx.str(), &id.try_with_ctx(ctx)?).into()); + return Err(UnexpectedType::new(&ctx.str(), &ctx.try_resolve(id)?).into()); } (SqlLiteral::Hex(v), id) | (SqlLiteral::Num(v), id) => { - let ty = id.try_with_ctx(ctx)?; - values.push((col_id, parse(v, ty)?)); + values.push((col_id, parse(v.into_string(), ctx.try_resolve(id)?)?)); } } } @@ -231,27 +230,35 @@ fn is_var_valid(var: &str) -> bool { } pub fn type_set(ctx: &TyCtx, set: SqlSet) -> TypingResult { - let SqlSet(SqlIdent { name, .. }, lit) = set; + let SqlSet(SqlIdent(name), lit) = set; if !is_var_valid(&name) { - return Err(InvalidVar { name }.into()); + return Err(InvalidVar { + name: name.into_string(), + } + .into()); } match lit { SqlLiteral::Bool(_) => Err(UnexpectedType::new(&ctx.u64(), &ctx.bool()).into()), SqlLiteral::Str(_) => Err(UnexpectedType::new(&ctx.u64(), &ctx.str()).into()), SqlLiteral::Hex(_) => Err(UnexpectedType::new(&ctx.u64(), &ctx.bytes()).into()), SqlLiteral::Num(n) => Ok(SetVar { - name, - value: parse(n, ctx.u64())?, + name: name.into_string(), + value: parse(n.into_string(), ctx.u64())?, }), } } pub fn type_show(show: SqlShow) -> TypingResult { - let SqlShow(SqlIdent { name, .. }) = show; + let SqlShow(SqlIdent(name)) = show; if !is_var_valid(&name) { - return Err(InvalidVar { name }.into()); + return Err(InvalidVar { + name: name.into_string(), + } + .into()); } - Ok(ShowVar { name }) + Ok(ShowVar { + name: name.into_string(), + }) } /// Type-checker for regular `SQL` queries @@ -341,13 +348,13 @@ impl TypeChecker for SqlChecker { } } -pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { +pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { match parse_sql(sql)? { - SqlAst::Insert(insert) => Ok(Stmt::Insert(type_insert(&mut TyCtx::default(), insert, tx)?)), - SqlAst::Delete(delete) => Ok(Stmt::Delete(type_delete(&mut TyCtx::default(), delete, tx)?)), - SqlAst::Update(update) => Ok(Stmt::Update(type_update(&mut TyCtx::default(), update, tx)?)), - SqlAst::Query(ast) => Ok(Stmt::Select(SqlChecker::type_ast(&mut TyCtx::default(), ast, tx)?)), - SqlAst::Set(set) => Ok(Stmt::Set(type_set(&TyCtx::default(), set)?)), - SqlAst::Show(show) => Ok(Stmt::Show(type_show(show)?)), + SqlAst::Insert(insert) => Ok(Statement::Insert(type_insert(&mut TyCtx::default(), insert, tx)?)), + SqlAst::Delete(delete) => Ok(Statement::Delete(type_delete(&mut TyCtx::default(), delete, tx)?)), + SqlAst::Update(update) => Ok(Statement::Update(type_update(&mut TyCtx::default(), update, tx)?)), + SqlAst::Query(ast) => Ok(Statement::Select(SqlChecker::type_ast(&mut TyCtx::default(), ast, tx)?)), + SqlAst::Set(set) => Ok(Statement::Set(type_set(&TyCtx::default(), set)?)), + SqlAst::Show(show) => Ok(Statement::Show(type_show(show)?)), } } diff --git a/crates/expr/src/ty.rs b/crates/expr/src/ty.rs new file mode 100644 index 0000000000..b7964a59bb --- /dev/null +++ b/crates/expr/src/ty.rs @@ -0,0 +1,639 @@ +use std::{ + collections::HashMap, + fmt::{Display, Formatter}, + ops::Deref, +}; + +use spacetimedb_lib::AlgebraicType; +use spacetimedb_primitives::TableId; +use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; +use spacetimedb_sql_parser::ast::BinOp; +use string_interner::{backend::StringBackend, symbol::SymbolU32, StringInterner}; +use thiserror::Error; + +use super::errors::{ExpectedRelation, InvalidOp}; + +/// When type checking a [super::expr::RelExpr], +/// types are stored in a typing context [TyCtx]. +/// It will then hold references, in the form of [TyId]s, +/// to the types defined in the [TyCtx]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TyId(u32); + +impl TyId { + /// A static type id for Bool + pub const BOOL: Self = Self(0); + + /// A static type id for I8 + pub const I8: Self = Self(1); + + /// A static type id for U8 + pub const U8: Self = Self(2); + + /// A static type id for I16 + pub const I16: Self = Self(3); + + /// A static type id for U16 + pub const U16: Self = Self(4); + + /// A static type id for I32 + pub const I32: Self = Self(5); + + /// A static type id for U32 + pub const U32: Self = Self(6); + + /// A static type id for I64 + pub const I64: Self = Self(7); + + /// A static type id for U64 + pub const U64: Self = Self(8); + + /// A static type id for I128 + pub const I128: Self = Self(9); + + /// A static type id for U128 + pub const U128: Self = Self(10); + + /// A static type id for I256 + pub const I256: Self = Self(11); + + /// A static type id for U256 + pub const U256: Self = Self(12); + + /// A static type id for F32 + pub const F32: Self = Self(13); + + /// A static type id for F64 + pub const F64: Self = Self(14); + + /// A static type id for String + pub const STR: Self = Self(15); + + /// A static type id for a byte array + pub const BYTES: Self = Self(16); + + /// A static type id for [AlgebraicType::identity()] + pub const IDENT: Self = Self(17); + + /// The number of statically defined [TyId]s + const N: usize = 18; +} + +impl Display for TyId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A symbol for names or identifiers in an expression tree +pub type Symbol = SymbolU32; + +/// The type of a relation or scalar expression +#[derive(Debug)] +pub enum Type { + /// A base relation + Var(TableId, Box<[(Symbol, TyId)]>), + /// A derived relation + Row(Box<[(Symbol, TyId)]>), + /// A column type + Alg(AlgebraicType), +} + +impl Type { + /// A constant for the primitive type Bool + pub const BOOL: Self = Self::Alg(AlgebraicType::Bool); + + /// A constant for the primitive type I8 + pub const I8: Self = Self::Alg(AlgebraicType::I8); + + /// A constant for the primitive type U8 + pub const U8: Self = Self::Alg(AlgebraicType::U8); + + /// A constant for the primitive type I16 + pub const I16: Self = Self::Alg(AlgebraicType::I16); + + /// A constant for the primitive type U16 + pub const U16: Self = Self::Alg(AlgebraicType::U16); + + /// A constant for the primitive type I32 + pub const I32: Self = Self::Alg(AlgebraicType::I32); + + /// A constant for the primitive type U32 + pub const U32: Self = Self::Alg(AlgebraicType::U32); + + /// A constant for the primitive type I64 + pub const I64: Self = Self::Alg(AlgebraicType::I64); + + /// A constant for the primitive type U64 + pub const U64: Self = Self::Alg(AlgebraicType::U64); + + /// A constant for the primitive type I128 + pub const I128: Self = Self::Alg(AlgebraicType::I128); + + /// A constant for the primitive type U128 + pub const U128: Self = Self::Alg(AlgebraicType::U128); + + /// A constant for the primitive type I256 + pub const I256: Self = Self::Alg(AlgebraicType::I256); + + /// A constant for the primitive type U256 + pub const U256: Self = Self::Alg(AlgebraicType::U256); + + /// A constant for the primitive type F32 + pub const F32: Self = Self::Alg(AlgebraicType::F32); + + /// A constant for the primitive type F64 + pub const F64: Self = Self::Alg(AlgebraicType::F64); + + /// A constant for the primitive type String + pub const STR: Self = Self::Alg(AlgebraicType::String); + + /// Is this type compatible with this binary operator? + pub fn is_compatible_with(&self, op: BinOp) -> bool { + match (op, self) { + (BinOp::And | BinOp::Or, Type::Alg(AlgebraicType::Bool)) => true, + (BinOp::And | BinOp::Or, _) => false, + (BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, Type::Alg(t)) => { + t.is_bool() + || t.is_integer() + || t.is_float() + || t.is_string() + || t.is_bytes() + || t.is_identity() + || t.is_address() + } + (BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, _) => false, + } + } +} + +/// When type checking a [super::expr::RelExpr], +/// types are stored in a typing context [TyCtx]. +/// It will then hold references, in the form of [TyId]s, +/// to the types defined in the [TyCtx]. +#[derive(Debug)] +pub struct TyCtx { + /// A statically interned byte array type + bytes: Type, + /// A statically interned identity type + ident: Type, + /// Types that are interned dynamically during type checking + types: Vec, + /// Interned identifiers + names: StringInterner, +} + +impl Default for TyCtx { + fn default() -> Self { + Self { + // Pre-intern the byte array type + bytes: Type::Alg(AlgebraicType::bytes()), + // Pre-intern the identity type + ident: Type::Alg(AlgebraicType::identity()), + // All other composite types are interned on the fly + types: vec![], + // Intern identifiers on the fly + names: StringInterner::new(), + } + } +} + +#[derive(Debug, Error)] +#[error("Invalid type id {0}")] +pub struct InvalidTypeId(TyId); + +impl TyCtx { + /// Return a wrapped [Type::BOOL] + pub fn bool(&self) -> TypeWithCtx { + TypeWithCtx(&Type::BOOL, self) + } + + /// Return a wrapped [Type::I8] + pub fn i8(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I8, self) + } + + /// Return a wrapped [Type::U8] + pub fn u8(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U8, self) + } + + /// Return a wrapped [Type::I16] + pub fn i16(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I16, self) + } + + /// Return a wrapped [Type::U16] + pub fn u16(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U16, self) + } + + /// Return a wrapped [Type::I32] + pub fn i32(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I32, self) + } + + /// Return a wrapped [Type::U32] + pub fn u32(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U32, self) + } + + /// Return a wrapped [Type::I64] + pub fn i64(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I64, self) + } + + /// Return a wrapped [Type::U64] + pub fn u64(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U64, self) + } + + /// Return a wrapped [Type::I128] + pub fn i128(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I128, self) + } + + /// Return a wrapped [Type::U128] + pub fn u128(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U128, self) + } + + /// Return a wrapped [Type::I256] + pub fn i256(&self) -> TypeWithCtx { + TypeWithCtx(&Type::I256, self) + } + + /// Return a wrapped [Type::U256] + pub fn u256(&self) -> TypeWithCtx { + TypeWithCtx(&Type::U256, self) + } + + /// Return a wrapped [Type::F32] + pub fn f32(&self) -> TypeWithCtx { + TypeWithCtx(&Type::F32, self) + } + + /// Return a wrapped [Type::F64] + pub fn f64(&self) -> TypeWithCtx { + TypeWithCtx(&Type::F64, self) + } + + /// Return a wrapped [Type::STR] + pub fn str(&self) -> TypeWithCtx { + TypeWithCtx(&Type::STR, self) + } + + /// Return a wrapped [AlgebraicType::bytes()] + pub fn bytes(&self) -> TypeWithCtx { + TypeWithCtx(&self.bytes, self) + } + + /// Return a wrapped [AlgebraicType::identity()] + pub fn ident(&self) -> TypeWithCtx { + TypeWithCtx(&self.ident, self) + } + + /// Try to resolve an id to its [Type]. + /// Return a resolution error if not found. + pub fn try_resolve(&self, id: TyId) -> Result { + match id { + TyId::BOOL => { + // Resolve the primitive type Bool + Ok(self.bool()) + } + TyId::I8 => { + // Resolve the primitive type I8 + Ok(self.i8()) + } + TyId::U8 => { + // Resolve the primitive type U8 + Ok(self.u8()) + } + TyId::I16 => { + // Resolve the primitive type I16 + Ok(self.i16()) + } + TyId::U16 => { + // Resolve the primitive type U16 + Ok(self.u16()) + } + TyId::I32 => { + // Resolve the primitive type I32 + Ok(self.i32()) + } + TyId::U32 => { + // Resolve the primitive type U32 + Ok(self.u32()) + } + TyId::I64 => { + // Resolve the primitive type I64 + Ok(self.i64()) + } + TyId::U64 => { + // Resolve the primitive type U64 + Ok(self.u64()) + } + TyId::I128 => { + // Resolve the primitive type I128 + Ok(self.i128()) + } + TyId::U128 => { + // Resolve the primitive type U128 + Ok(self.u128()) + } + TyId::I256 => { + // Resolve the primitive type I256 + Ok(self.i256()) + } + TyId::U256 => { + // Resolve the primitive type U256 + Ok(self.u256()) + } + TyId::F32 => { + // Resolve the primitive type F32 + Ok(self.f32()) + } + TyId::F64 => { + // Resolve the primitive type F64 + Ok(self.f64()) + } + TyId::STR => { + // Resolve the primitive type String + Ok(self.str()) + } + TyId::BYTES => { + // Resolve the byte array type + Ok(self.bytes()) + } + TyId::IDENT => { + // Resolve the special identity type + Ok(self.ident()) + } + _ => self + .types + .get(id.0 as usize - TyId::N) + .map(|ty| TypeWithCtx(ty, self)) + .ok_or(InvalidTypeId(id)), + } + } + + /// Resolve a [Symbol] to its name + pub fn resolve_symbol(&self, id: Symbol) -> Option<&str> { + self.names.resolve(id) + } + + /// Add an [AlgebraicType] to the context and return a [TyId] for it. + /// The [TyId] is not guaranteed to be unique to the type. + /// However for primitive types it will be. + pub fn add_algebraic_type(&mut self, ty: &AlgebraicType) -> TyId { + match ty { + AlgebraicType::Bool => { + // Bool -> BOOL + TyId::BOOL + } + AlgebraicType::I8 => { + // I8 -> I8 + TyId::I8 + } + AlgebraicType::U8 => { + // U8 -> U8 + TyId::U8 + } + AlgebraicType::I16 => { + // I16 -> I16 + TyId::I16 + } + AlgebraicType::U16 => { + // U16 -> U16 + TyId::U16 + } + AlgebraicType::I32 => { + // I32 -> I32 + TyId::I32 + } + AlgebraicType::U32 => { + // U32 -> U32 + TyId::U32 + } + AlgebraicType::I64 => { + // I64 -> I64 + TyId::I64 + } + AlgebraicType::U64 => { + // U64 -> U64 + TyId::U64 + } + AlgebraicType::I128 => { + // I128 -> I128 + TyId::I128 + } + AlgebraicType::U128 => { + // U128 -> U128 + TyId::U128 + } + AlgebraicType::I256 => { + // I256 -> I256 + TyId::I256 + } + AlgebraicType::U256 => { + // U256 -> U256 + TyId::U256 + } + AlgebraicType::F32 => { + // F32 -> F32 + TyId::F32 + } + AlgebraicType::F64 => { + // F64 -> F64 + TyId::F64 + } + AlgebraicType::String => { + // String -> STR + TyId::STR + } + AlgebraicType::Array(ty) if ty.elem_ty.is_u8() => { + // [u8] -> BYTES + TyId::BYTES + } + AlgebraicType::Product(ty) if ty.is_identity() => { + // { __identity_bytes: [u8] } -> IDENT + TyId::IDENT + } + _ => { + let n = self.types.len() + TyId::N; + self.types.push(Type::Alg(ty.clone())); + TyId(n as u32) + } + } + } + + /// Add a relvar or table type to the context and return a [TyId] for it. + /// The [TyId] is not guaranteed to be unique to the type. + pub fn add_var_type(&mut self, table_id: TableId, fields: Vec<(Symbol, TyId)>) -> TyId { + let n = self.types.len() + TyId::N; + self.types.push(Type::Var(table_id, fields.into_boxed_slice())); + TyId(n as u32) + } + + /// Add a derived row type to the context and return a [TyId] for it. + /// The [TyId] is not guaranteed to be unique to the type. + pub fn add_row_type(&mut self, fields: Vec<(Symbol, TyId)>) -> TyId { + let n = self.types.len() + TyId::N; + self.types.push(Type::Row(fields.into_boxed_slice())); + TyId(n as u32) + } + + /// Generate a [Symbol] from a string + pub fn gen_symbol(&mut self, name: impl AsRef) -> Symbol { + self.names.get_or_intern(name) + } + + /// Get an already generated [Symbol] + pub fn get_symbol(&self, name: impl AsRef) -> Option { + self.names.get(name) + } + + /// Are these types structurally equivalent? + pub fn eq(&self, a: TyId, b: TyId) -> Result { + if a.0 < TyId::N as u32 || b.0 < TyId::N as u32 { + return Ok(a == b); + } + match (&*self.try_resolve(a)?, &*self.try_resolve(b)?) { + (Type::Alg(a), Type::Alg(b)) => Ok(a == b), + (Type::Var(a, _), Type::Var(b, _)) => Ok(a == b), + (Type::Row(a), Type::Row(b)) => Ok(a.len() == b.len() && { + for (i, (name, id)) in a.iter().enumerate() { + if name != &b[i].0 || !self.eq(*id, b[i].1)? { + return Ok(false); + } + } + true + }), + _ => Ok(false), + } + } +} + +/// A type wrapped with its typing context +#[derive(Debug)] +pub struct TypeWithCtx<'a>(&'a Type, &'a TyCtx); + +impl Deref for TypeWithCtx<'_> { + type Target = Type; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl TypeWithCtx<'_> { + /// Expect a type compatible with this binary operator + pub fn expect_op(&self, op: BinOp) -> Result<(), InvalidOp> { + if self.0.is_compatible_with(op) { + return Ok(()); + } + Err(InvalidOp::new(op, self)) + } + + /// Expect a relvar or base relation type + pub fn expect_relvar(&self) -> Result { + match self.0 { + Type::Var(_, fields) => Ok(RelType { fields }), + Type::Row(_) | Type::Alg(_) => Err(ExpectedRelvar), + } + } + + /// Expect a scalar or column type, not a relation type + pub fn expect_scalar(&self) -> Result<&AlgebraicType, ExpectedScalar> { + match self.0 { + Type::Alg(t) => Ok(t), + Type::Var(..) | Type::Row(..) => Err(ExpectedScalar), + } + } + + /// Expect a relation, not a scalar or column type + pub fn expect_relation(&self) -> Result { + match self.0 { + Type::Var(_, fields) | Type::Row(fields) => Ok(RelType { fields }), + Type::Alg(_) => Err(ExpectedRelation::new(self)), + } + } +} + +/// The error type of [TypeWithCtx::expect_relvar()] +pub struct ExpectedRelvar; + +/// The error type of [TypeWithCtx::expect_scalar()] +pub struct ExpectedScalar; + +impl<'a> Display for TypeWithCtx<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.0 { + Type::Alg(ty) => write!(f, "{}", fmt_algebraic_type(ty)), + Type::Var(_, fields) | Type::Row(fields) => { + const UNKNOWN: &str = "UNKNOWN"; + write!(f, "(")?; + let (symbol, id) = &fields[0]; + let name = self.1.resolve_symbol(*symbol).unwrap_or(UNKNOWN); + match self.1.try_resolve(*id) { + Ok(ty) => { + write!(f, "{}: {}", name, ty)?; + } + Err(_) => { + write!(f, "{}: {}", name, UNKNOWN)?; + } + }; + for (symbol, id) in &fields[1..] { + let name = self.1.resolve_symbol(*symbol).unwrap_or(UNKNOWN); + match self.1.try_resolve(*id) { + Ok(ty) => { + write!(f, "{}: {}", name, ty)?; + } + Err(_) => { + write!(f, "{}: {}", name, UNKNOWN)?; + } + }; + } + write!(f, ")") + } + } + } +} + +/// Represents a non-scalar or column type +#[derive(Debug)] +pub struct RelType<'a> { + fields: &'a [(Symbol, TyId)], +} + +impl<'a> RelType<'a> { + /// Returns an iterator over the field names and types of this row type + pub fn iter(&'a self) -> impl Iterator + '_ { + self.fields.iter().enumerate().map(|(i, (name, ty))| (i, *name, *ty)) + } + + /// Find the position and type of a field in this row type if it exists + pub fn find(&'a self, name: Symbol) -> Option<(usize, TyId)> { + self.iter() + .find(|(_, field, _)| *field == name) + .map(|(i, _, ty)| (i, ty)) + } +} + +/// A typing environment for an expression. +/// It binds in scope variables to their respective types. +#[derive(Debug, Clone, Default)] +pub struct TyEnv(HashMap); + +impl TyEnv { + /// Adds a new variable binding to the environment. + /// Returns the old binding if the name was already in scope. + pub fn add(&mut self, name: Symbol, ty: TyId) -> Option { + self.0.insert(name, ty) + } + + /// Find a name in the environment + pub fn find(&self, name: Symbol) -> Option { + self.0.get(&name).copied() + } +} diff --git a/crates/planner/src/lib.rs b/crates/planner/src/lib.rs deleted file mode 100644 index 82d25f24e8..0000000000 --- a/crates/planner/src/lib.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod logical; diff --git a/crates/planner/src/logical/ty.rs b/crates/planner/src/logical/ty.rs deleted file mode 100644 index a4046646e1..0000000000 --- a/crates/planner/src/logical/ty.rs +++ /dev/null @@ -1,340 +0,0 @@ -use std::{ - collections::HashMap, - fmt::{Display, Formatter}, - ops::Deref, -}; - -use spacetimedb_lib::AlgebraicType; -use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; -use spacetimedb_sql_parser::ast::BinOp; -use string_interner::{backend::StringBackend, symbol::SymbolU32, StringInterner}; -use thiserror::Error; - -use crate::static_assert_size; - -use super::errors::{ExpectedRelation, InvalidOp}; - -/// When type checking a [super::expr::RelExpr], -/// types are stored in a typing context [TyCtx]. -/// It will then hold references, in the form of [TyId]s, -/// to the types defined in the [TyCtx]. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct TyId(u32); - -impl TyId { - /// The number of primitive types whose [TyId]s are statically defined. - pub const N: usize = 18; - - /// The static type id for Bool - /// The value is determined by [TyCtx::default()] - pub const BOOL: Self = Self(0); - - /// The static type id for U64 - /// The value is determined by [TyCtx::default()] - pub const U64: Self = Self(8); - - /// The static type id for String - /// The value is determined by [TyCtx::default()] - pub const STR: Self = Self(15); - - /// the static type id for a byte array - /// The value is determined by [TyCtx::default()] - pub const BYTES: Self = Self(16); - - /// Return the [Type] for this id with its typing context. - /// Return an error if the id is not valid for the context. - pub fn try_with_ctx(self, ctx: &TyCtx) -> Result { - ctx.try_resolve(self) - } -} - -impl Display for TyId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -/// A symbol for names or identifiers in an expression tree -pub type Symbol = SymbolU32; - -/// The type of a relation or scalar expression -#[derive(Debug)] -pub enum Type { - /// A base relation - Var(Box<[(Symbol, TyId)]>), - /// A derived relation - Row(Box<[(Symbol, TyId)]>), - /// A column type - Alg(AlgebraicType), -} - -static_assert_size!(Type, 24); - -impl Type { - /// Is this type compatible with this binary operator? - pub fn is_compatible_with(&self, op: BinOp) -> bool { - match (op, self) { - (BinOp::And | BinOp::Or, Type::Alg(AlgebraicType::Bool)) => true, - (BinOp::And | BinOp::Or, _) => false, - (BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, Type::Alg(t)) => { - t.is_bool() - || t.is_integer() - || t.is_float() - || t.is_string() - || t.is_bytes() - || t.is_identity() - || t.is_address() - } - (BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, _) => false, - } - } -} - -/// When type checking a [super::expr::RelExpr], -/// types are stored in a typing context [TyCtx]. -/// It will then hold references, in the form of [TyId]s, -/// to the types defined in the [TyCtx]. -#[derive(Debug)] -pub struct TyCtx { - types: Vec, - names: StringInterner, -} - -impl Default for TyCtx { - fn default() -> Self { - Self { - names: StringInterner::new(), - types: vec![ - Type::Alg(AlgebraicType::Bool), - Type::Alg(AlgebraicType::I8), - Type::Alg(AlgebraicType::U8), - Type::Alg(AlgebraicType::I16), - Type::Alg(AlgebraicType::U16), - Type::Alg(AlgebraicType::I32), - Type::Alg(AlgebraicType::U32), - Type::Alg(AlgebraicType::I64), - Type::Alg(AlgebraicType::U64), - Type::Alg(AlgebraicType::I128), - Type::Alg(AlgebraicType::U128), - Type::Alg(AlgebraicType::I256), - Type::Alg(AlgebraicType::U256), - Type::Alg(AlgebraicType::F32), - Type::Alg(AlgebraicType::F64), - Type::Alg(AlgebraicType::String), - Type::Alg(AlgebraicType::bytes()), - Type::Alg(AlgebraicType::identity()), - ], - } - } -} - -#[derive(Debug, Error)] -#[error("Invalid type id {0}")] -pub struct InvalidTypeId(TyId); - -impl TyCtx { - /// Try to resolve an id to its [Type]. - /// Return a resolution error if not found. - pub fn try_resolve(&self, id: TyId) -> Result { - self.types - .get(id.0 as usize) - .map(|ty| TypeWithCtx(ty, self)) - .ok_or(InvalidTypeId(id)) - } - - /// Resolve a [Symbol] to its name - pub fn resolve_symbol(&self, id: Symbol) -> Option<&str> { - self.names.resolve(id) - } - - /// Add a type to the context and return a [TyId] for it. - /// The [TyId] is not guaranteed to be unique to the type. - /// However for primitive types it will be. - pub fn add(&mut self, ty: Type) -> TyId { - if let Type::Alg(t) = &ty { - for i in 0..TyId::N { - if let Type::Alg(s) = &self.types[i] { - if s == t { - return TyId(i as u32); - } - } - } - } - self.types.push(ty); - let n = self.types.len() - 1; - TyId(n as u32) - } - - /// Generate a [Symbol] from a string - pub fn gen_symbol(&mut self, name: impl AsRef) -> Symbol { - self.names.get_or_intern(name) - } - - /// Get an already generated [Symbol] - pub fn get_symbol(&self, name: impl AsRef) -> Option { - self.names.get(name) - } - - /// A wrapped [AlgebraicType::Bool] type - pub fn bool(&self) -> TypeWithCtx { - TypeWithCtx(&Type::Alg(AlgebraicType::Bool), self) - } - - /// A wrapped [AlgebraicType::String] type - pub fn str(&self) -> TypeWithCtx { - TypeWithCtx(&Type::Alg(AlgebraicType::String), self) - } - - /// A wrapped [AlgebraicType::U64] type - pub fn u64(&self) -> TypeWithCtx { - TypeWithCtx(&Type::Alg(AlgebraicType::U64), self) - } - - /// A wrapped [AlgebraicType::bytes()] type - pub fn bytes(&self) -> TypeWithCtx { - TypeWithCtx(&self.types[TyId::BYTES.0 as usize], self) - } - - /// Are these types structurally equivalent? - pub fn eq(&self, a: TyId, b: TyId) -> Result { - if a.0 < TyId::N as u32 || b.0 < TyId::N as u32 { - return Ok(a == b); - } - match (&*self.try_resolve(a)?, &*self.try_resolve(b)?) { - (Type::Alg(a), Type::Alg(b)) => Ok(a == b), - (Type::Var(a), Type::Var(b)) | (Type::Row(a), Type::Row(b)) => Ok(a.len() == b.len() && { - for (i, (name, id)) in a.iter().enumerate() { - if name != &b[i].0 || !self.eq(*id, b[i].1)? { - return Ok(false); - } - } - true - }), - _ => Ok(false), - } - } -} - -/// A type wrapped with its typing context -#[derive(Debug)] -pub struct TypeWithCtx<'a>(&'a Type, &'a TyCtx); - -impl Deref for TypeWithCtx<'_> { - type Target = Type; - - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl TypeWithCtx<'_> { - /// Expect a type compatible with this binary operator - pub fn expect_op(&self, op: BinOp) -> Result<(), InvalidOp> { - if self.0.is_compatible_with(op) { - return Ok(()); - } - Err(InvalidOp::new(op, self)) - } - - /// Expect a relvar or base relation type - pub fn expect_relvar(&self) -> Result { - match self.0 { - Type::Var(fields) => Ok(RelType { fields }), - Type::Row(_) | Type::Alg(_) => Err(ExpectedRelvar), - } - } - - /// Expect a scalar or column type, not a relation type - pub fn expect_scalar(&self) -> Result<&AlgebraicType, ExpectedScalar> { - match self.0 { - Type::Alg(t) => Ok(t), - Type::Var(_) | Type::Row(_) => Err(ExpectedScalar), - } - } - - /// Expect a relation, not a scalar or column type - pub fn expect_relation(&self) -> Result { - match self.0 { - Type::Var(fields) | Type::Row(fields) => Ok(RelType { fields }), - Type::Alg(_) => Err(ExpectedRelation::new(self)), - } - } -} - -/// The error type of [TypeWithCtx::expect_relvar()] -pub struct ExpectedRelvar; - -/// The error type of [TypeWithCtx::expect_scalar()] -pub struct ExpectedScalar; - -impl<'a> Display for TypeWithCtx<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.0 { - Type::Alg(ty) => write!(f, "{}", fmt_algebraic_type(ty)), - Type::Var(fields) | Type::Row(fields) => { - const UNKNOWN: &str = "UNKNOWN"; - write!(f, "(")?; - let (symbol, id) = &fields[0]; - let name = self.1.resolve_symbol(*symbol).unwrap_or(UNKNOWN); - match self.1.try_resolve(*id) { - Ok(ty) => { - write!(f, "{}: {}", name, ty)?; - } - Err(_) => { - write!(f, "{}: {}", name, UNKNOWN)?; - } - }; - for (symbol, id) in &fields[1..] { - let name = self.1.resolve_symbol(*symbol).unwrap_or(UNKNOWN); - match self.1.try_resolve(*id) { - Ok(ty) => { - write!(f, "{}: {}", name, ty)?; - } - Err(_) => { - write!(f, "{}: {}", name, UNKNOWN)?; - } - }; - } - write!(f, ")") - } - } - } -} - -/// Represents a non-scalar or column type -#[derive(Debug)] -pub struct RelType<'a> { - fields: &'a [(Symbol, TyId)], -} - -impl<'a> RelType<'a> { - /// Returns an iterator over the field names and types of this row type - pub fn iter(&'a self) -> impl Iterator + '_ { - self.fields.iter().enumerate().map(|(i, (name, ty))| (i, *name, *ty)) - } - - /// Find the position and type of a field in this row type if it exists - pub fn find(&'a self, name: Symbol) -> Option<(usize, TyId)> { - self.iter() - .find(|(_, field, _)| *field == name) - .map(|(i, _, ty)| (i, ty)) - } -} - -/// A typing environment for an expression. -/// It binds in scope variables to their respective types. -#[derive(Debug, Clone, Default)] -pub struct TyEnv(HashMap); - -impl TyEnv { - /// Adds a new variable binding to the environment. - /// Returns the old binding if the name was already in scope. - pub fn add(&mut self, name: Symbol, ty: TyId) -> Option { - self.0.insert(name, ty) - } - - /// Find a name in the environment - pub fn find(&self, name: Symbol) -> Option { - self.0.get(&name).copied() - } -} diff --git a/crates/sql-parser/src/ast/mod.rs b/crates/sql-parser/src/ast/mod.rs index 7487c27e56..43e7e9ebce 100644 --- a/crates/sql-parser/src/ast/mod.rs +++ b/crates/sql-parser/src/ast/mod.rs @@ -61,31 +61,15 @@ pub enum SqlExpr { Bin(Box, Box, BinOp), } -/// A SQL identifier or named reference +/// A SQL identifier or named reference. +/// Currently case sensitive. #[derive(Debug, Clone)] -pub struct SqlIdent { - pub name: String, - pub case_sensitive: bool, -} +pub struct SqlIdent(pub Box); +/// Case insensitivity should be implemented here if at all impl From for SqlIdent { - fn from(value: Ident) -> Self { - match value { - Ident { - value: name, - quote_style: None, - } => SqlIdent { - name, - case_sensitive: false, - }, - Ident { - value: name, - quote_style: Some(_), - } => SqlIdent { - name, - case_sensitive: true, - }, - } + fn from(Ident { value, .. }: Ident) -> Self { + SqlIdent(value.into_boxed_str()) } } @@ -95,11 +79,11 @@ pub enum SqlLiteral { /// A boolean constant Bool(bool), /// A hex value like 0xFF or x'FF' - Hex(String), + Hex(Box), /// An integer or float value - Num(String), + Num(Box), /// A string value - Str(String), + Str(Box), } /// Binary infix operators diff --git a/crates/sql-parser/src/parser/mod.rs b/crates/sql-parser/src/parser/mod.rs index 4d01f30934..dc2b49a531 100644 --- a/crates/sql-parser/src/parser/mod.rs +++ b/crates/sql-parser/src/parser/mod.rs @@ -212,9 +212,9 @@ pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult { pub(crate) fn parse_literal(value: Value) -> SqlParseResult { match value { Value::Boolean(v) => Ok(SqlLiteral::Bool(v)), - Value::Number(v, _) => Ok(SqlLiteral::Num(v)), - Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s)), - Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s)), + Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())), + Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())), + Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())), _ => Err(SqlUnsupported::Literal(value).into()), } }