Skip to content

Commit

Permalink
feat: Type check SQL DML statements
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-spacetime committed Sep 24, 2024
1 parent eadea95 commit 9c03e47
Show file tree
Hide file tree
Showing 7 changed files with 411 additions and 16 deletions.
2 changes: 1 addition & 1 deletion crates/planner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ license-file = "LICENSE"
derive_more.workspace = true
thiserror.workspace = true
spacetimedb-lib.workspace = true
spacetimedb-primitives.workspace = true
spacetimedb-sats.workspace = true
spacetimedb-schema.workspace = true
spacetimedb-sql-parser.workspace = true

[dev-dependencies]
spacetimedb-lib.workspace = true
spacetimedb-primitives.workspace = true
126 changes: 119 additions & 7 deletions crates/planner/src/logical/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,122 @@ pub trait SchemaView {
fn schema(&self, name: &str, case_sensitive: bool) -> Option<Arc<TableSchema>>;
}

pub trait TypeChecker {
type Ast;
type Set;

fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr>;

fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr>;

fn type_from(ctx: &mut TyCtx, from: SqlFrom<Self::Ast>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
match from {
SqlFrom::Expr(expr, None) => Self::type_rel(ctx, expr, tx),
SqlFrom::Expr(expr, Some(alias)) => {
let (expr, _) = Self::type_rel(ctx, expr, tx)?;
let ty = expr.ty_id();
Ok((expr, vec![(alias.name, ty)].into()))
}
SqlFrom::Join(r, alias, joins) => {
let (mut vars, mut args, mut exprs) = (Vars::default(), Vec::new(), Vec::new());

let (r, _) = Self::type_rel(ctx, r, tx)?;
let ty = r.ty_id();

args.push(r);
vars.push((alias.name, ty));

for join in joins {
let (r, _) = Self::type_rel(ctx, join.expr, tx)?;
let ty = r.ty_id();

args.push(r);
vars.push((join.alias.name, ty));

if let Some(on) = join.on {
exprs.push(type_expr(ctx, &vars, on, Some(TyId::BOOL))?);
}
}
let types = vars.iter().map(|(_, ty)| *ty).collect();
let ty = Type::Tup(types);
let input = RelExpr::Join(args.into(), ctx.add(ty));
Ok((RelExpr::select(input, vars.clone(), exprs), vars))
}
}
}

fn type_rel(ctx: &mut TyCtx, expr: ast::RelExpr<Self::Ast>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
match expr {
ast::RelExpr::Var(var) => {
let schema = tx
.schema(&var.name, var.case_sensitive)
.ok_or_else(|| Unresolved::table(&var.name))
.map_err(TypingError::from)?;
let mut types = Vec::new();
for ColumnSchema { col_name, col_type, .. } in schema.columns() {
let ty = Type::Alg(col_type.clone());
let id = ctx.add(ty);
types.push((col_name.to_string(), id));
}
let ty = Type::Var(types.into_boxed_slice());
let id = ctx.add(ty);
Ok((RelExpr::RelVar(schema, id), vec![(var.name, id)].into()))
}
ast::RelExpr::Ast(ast) => Ok((Self::type_ast(ctx, *ast, tx)?, Vars::default())),
}
}
}

/// Type checker for subscriptions
struct SubChecker;

impl TypeChecker for SubChecker {
type Ast = SqlAst;
type Set = SqlAst;

fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr> {
Self::type_set(ctx, ast, tx)
}

fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr> {
match ast {
SqlAst::Union(a, b) => {
let a = type_ast(ctx, *a, tx)?;
let b = type_ast(ctx, *b, tx)?;
assert_eq_types(a.ty_id().try_with_ctx(ctx)?, b.ty_id().try_with_ctx(ctx)?)?;
Ok(RelExpr::Union(Box::new(a), Box::new(b)))
}
SqlAst::Minus(a, b) => {
let a = type_ast(ctx, *a, tx)?;
let b = type_ast(ctx, *b, tx)?;
assert_eq_types(a.ty_id().try_with_ctx(ctx)?, b.ty_id().try_with_ctx(ctx)?)?;
Ok(RelExpr::Minus(Box::new(a), Box::new(b)))
}
SqlAst::Select(SqlSelect {
project,
from,
filter: None,
}) => {
let (arg, vars) = type_from(ctx, from, tx)?;
type_proj(ctx, project, arg, vars)
}
SqlAst::Select(SqlSelect {
project,
from,
filter: Some(expr),
}) => {
let (from, vars) = type_from(ctx, from, tx)?;
let arg = type_select(ctx, expr, from, vars.clone())?;
type_proj(ctx, project, arg, vars.clone())
}
}
}
}

/// Parse and type check a subscription query
pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<RelExpr> {
let mut ctx = TyCtx::default();
let expr = type_ast(&mut ctx, parse_subscription(sql)?, tx)?;
let expr = SubChecker::type_ast(&mut ctx, parse_subscription(sql)?, tx)?;
expect_table_type(&ctx, expr)
}

Expand Down Expand Up @@ -128,13 +240,13 @@ fn type_rel(ctx: &mut TyCtx, expr: ast::RelExpr<SqlAst>, tx: &impl SchemaView) -
}

/// Type check and lower a [SqlExpr]
fn type_select(ctx: &mut TyCtx, expr: SqlExpr, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
pub(crate) fn type_select(ctx: &mut TyCtx, expr: SqlExpr, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
let exprs = vec![type_expr(ctx, &vars, expr, Some(TyId::BOOL))?];
Ok(RelExpr::select(input, vars, exprs))
}

/// Type check and lower a [ast::Project]
fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
pub(crate) fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
match proj {
ast::Project::Star(None) => Ok(input),
ast::Project::Star(Some(var)) => {
Expand Down Expand Up @@ -167,7 +279,7 @@ fn type_proj(ctx: &mut TyCtx, proj: ast::Project, input: RelExpr, vars: Vars) ->
}

/// Type check and lower a [SqlExpr] into a logical [Expr].
fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) -> TypingResult<Expr> {
pub(crate) fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) -> TypingResult<Expr> {
match (expr, expected) {
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(TyId::BOOL)) => Ok(Expr::bool(v)),
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(id)) => {
Expand Down Expand Up @@ -195,7 +307,7 @@ fn type_expr(ctx: &TyCtx, vars: &Vars, expr: SqlExpr, expected: Option<TyId>) ->
}

/// Parses a source text literal as a particular type
fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
pub(crate) fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
let err = |v, ty| TypingError::from(ConstraintViolation::lit(v, ty));
match ctx.try_resolve(id)? {
ty @ Type::Alg(AlgebraicType::I8) => v
Expand Down Expand Up @@ -260,7 +372,7 @@ fn parse(ctx: &TyCtx, v: String, id: TyId) -> TypingResult<AlgebraicValue> {
}

/// Returns a type constraint violation for an unexpected type
fn unexpected_type(expected: TypeWithCtx<'_>, inferred: TypeWithCtx<'_>) -> TypingError {
pub(crate) fn unexpected_type(expected: TypeWithCtx<'_>, inferred: TypeWithCtx<'_>) -> TypingError {
ConstraintViolation::eq(expected, inferred).into()
}

Expand All @@ -282,7 +394,7 @@ fn expect_op_type(ctx: &TyCtx, op: BinOp, expr: Expr) -> TypingResult<Expr> {
}
}

fn assert_eq_types(a: TypeWithCtx<'_>, b: TypeWithCtx<'_>) -> TypingResult<()> {
pub(crate) fn assert_eq_types(a: TypeWithCtx<'_>, b: TypeWithCtx<'_>) -> TypingResult<()> {
if a == b {
Ok(())
} else {
Expand Down
10 changes: 10 additions & 0 deletions crates/planner/src/logical/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ pub enum Unsupported {
UnqualifiedProjectExpr,
}

#[derive(Error, Debug)]
#[error("Inserting a row with {values} values into {table} which has {fields} fields")]
pub struct InsertError {
pub table: String,
pub values: usize,
pub fields: usize,
}

#[derive(Error, Debug)]
pub enum TypingError {
#[error(transparent)]
Expand All @@ -100,5 +108,7 @@ pub enum TypingError {
#[error(transparent)]
InvalidTyId(#[from] InvalidTyId),
#[error(transparent)]
Insert(#[from] InsertError),
#[error(transparent)]
ParseError(#[from] SqlParseError),
}
1 change: 1 addition & 0 deletions crates/planner/src/logical/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bind;
pub mod errors;
pub mod expr;
pub mod stmt;
pub mod ty;

/// Asserts that `$ty` is `$size` bytes in `static_assert_size($ty, $size)`.
Expand Down
Loading

0 comments on commit 9c03e47

Please sign in to comment.