Skip to content

Commit

Permalink
aliases required
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-spacetime committed Sep 13, 2024
1 parent bdd4d0b commit 6f36ef3
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 206 deletions.
246 changes: 105 additions & 141 deletions crates/planner/src/logical/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,31 @@ use spacetimedb_sql_parser::{
parser::sub::parse_subscription,
};

use super::errors::{ResolutionError, TypeError};
use super::expr::{Expr, Project, Ref, RelExpr, Select, Type, Vars};
use super::errors::{ConstraintViolation, ResolutionError, TypingError, Unsupported};
use super::expr::{Expr, Ref, RelExpr, Type, Vars};

pub type TypeResult<T> = core::result::Result<T, TypeError>;
pub type TypingResult<T> = core::result::Result<T, TypingError>;

pub trait SchemaView {
fn schema(&self, name: &str, case_sensitive: bool) -> Option<Arc<TableSchema>>;
}

pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypeResult<RelExpr> {
let (expr, _) = bind_and_check(parse_subscription(sql)?, tx)?;
expect_table_type("", expr.ty())?;
Ok(expr)
/// Parse and type check a subscription query
pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<RelExpr> {
expect_table_type(bind_and_check(parse_subscription(sql)?, tx)?)
}

/// Bind variables, check types, and lower a [SqlAst] into a [RelExpr]
pub fn bind_and_check(expr: SqlAst, tx: &impl SchemaView) -> TypeResult<(RelExpr, Vars)> {
pub fn bind_and_check(expr: SqlAst, tx: &impl SchemaView) -> TypingResult<RelExpr> {
match expr {
SqlAst::Union(a, b) => {
let (a, _) = bind_and_check(*a, tx)?;
let (b, vars) = bind_and_check(*b, tx)?;
Ok((RelExpr::Union(Box::new(a), Box::new(b)), vars))
}
SqlAst::Minus(a, b) => {
let (a, _) = bind_and_check(*a, tx)?;
let (b, vars) = bind_and_check(*b, tx)?;
Ok((RelExpr::Minus(Box::new(a), Box::new(b)), vars))
}
SqlAst::Union(a, b) => Ok(RelExpr::Union(
Box::new(bind_and_check(*a, tx)?),
Box::new(bind_and_check(*b, tx)?),
)),
SqlAst::Minus(a, b) => Ok(RelExpr::Minus(
Box::new(bind_and_check(*a, tx)?),
Box::new(bind_and_check(*b, tx)?),
)),
SqlAst::Select(SqlSelect {
project,
from,
Expand All @@ -52,28 +49,22 @@ pub fn bind_and_check(expr: SqlAst, tx: &impl SchemaView) -> TypeResult<(RelExpr
from,
filter: Some(expr),
}) => {
let (arg, vars) = bind_and_check_from(from, tx)?;
let (arg, vars) = bind_and_check_sel(expr, arg, vars)?;
bind_and_check_proj(project, arg, vars)
let (from, vars) = bind_and_check_from(from, tx)?;
let arg = bind_and_check_sel(expr, from, vars.clone())?;
bind_and_check_proj(project, arg, vars.clone())
}
}
}

/// Bind variables and type check the relation expression in a FROM clause
pub fn bind_and_check_from(from: SqlFrom<SqlAst>, tx: &impl SchemaView) -> TypeResult<(RelExpr, Vars)> {
pub fn bind_and_check_from(from: SqlFrom<SqlAst>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
match from {
// If the input relation is not aliased,
// do not generate new variable bindings.
SqlFrom::Expr(expr, None) => bind_and_check_rel(expr, tx),
// If the input relation is aliased,
// create a single variable binding for the alias.
SqlFrom::Expr(expr, Some(alias)) => {
let (expr, _) = bind_and_check_rel(expr, tx)?;
let ty = expr.ty().clone();
Ok((expr, vec![(alias.name, ty)].into()))
}
// If the input relation is a join,
// create new bindings from the input aliases.
SqlFrom::Join(r, alias, joins) => {
let (mut vars, mut args, mut exprs) = (Vars::new(), Vec::new(), Vec::new());

Expand All @@ -94,169 +85,134 @@ pub fn bind_and_check_from(from: SqlFrom<SqlAst>, tx: &impl SchemaView) -> TypeR
exprs.push(check_and_lower_expr(&vars, on, Some(&Type::BOOL))?);
}
}
let ty = Type::Tup(vars.iter().map(|(_, ty)| ty.clone()).collect());
let input = Box::new(RelExpr::Join(args, ty));
let next = vars.clone();
Ok((RelExpr::Select(Select { input, vars, exprs }), next))
let types = vars.iter().map(|(_, ty)| ty.clone()).collect();
let input = RelExpr::Join(args, Type::Tup(types));
Ok((RelExpr::select(input, vars.clone(), exprs), vars))
}
}
}

/// Bind variables and type check a relation expression
pub fn bind_and_check_rel(expr: ast::RelExpr<SqlAst>, tx: &impl SchemaView) -> TypeResult<(RelExpr, Vars)> {
pub fn bind_and_check_rel(expr: ast::RelExpr<SqlAst>, tx: &impl SchemaView) -> TypingResult<(RelExpr, Vars)> {
match expr {
// A relvar brings new variables into scope.
ast::RelExpr::Var(var) => tx
.schema(&var.name, var.case_sensitive)
.ok_or_else(|| ResolutionError::unresolved_table(&var.name).into())
.map(|schema| {
(
RelExpr::RelVar(schema.clone(), Type::Var(schema.clone())),
Vars::from(schema.as_ref()),
vec![(var.name, Type::Var(schema))].into(),
)
}),
ast::RelExpr::Ast(ast) => bind_and_check(*ast, tx),
ast::RelExpr::Ast(ast) => Ok((bind_and_check(*ast, tx)?, Vars::new())),
}
}

/// Bind variables and type check a selection
fn bind_and_check_sel(expr: SqlExpr, input: RelExpr, vars: Vars) -> TypeResult<(RelExpr, Vars)> {
Ok((
RelExpr::Select(Select {
input: Box::new(input),
vars: vars.clone(),
exprs: vec![check_and_lower_expr(&vars, expr, Some(&Type::BOOL))?],
}),
vars,
))
fn bind_and_check_sel(expr: SqlExpr, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
let exprs = vec![check_and_lower_expr(&vars, expr, Some(&Type::BOOL))?];
Ok(RelExpr::select(input, vars, exprs))
}

/// Bind variables and type check a projection
fn bind_and_check_proj(proj: ast::Project, input: RelExpr, vars: Vars) -> TypeResult<(RelExpr, Vars)> {
fn bind_and_check_proj(proj: ast::Project, input: RelExpr, vars: Vars) -> TypingResult<RelExpr> {
match proj {
// SELECT * is redundant in this representation
ast::Project::Star(None) => Ok((input, vars)),
ast::Project::Star(None) => Ok(input),
ast::Project::Star(Some(var)) => {
let (i, ty) = vars.expect_var(&var.name, None)?;
// SELECT a.* only valid on table types
let schema = expect_table_type(&var.name, ty)?;
// The output type is that of a relvar
let ty = Type::Var(schema.clone());
// The fields of this relvar are now in scope
let next = Vars::from(schema.as_ref());
let input = Box::new(input);
let ty = ty.clone();
let refs = vec![Ref::Var(i, ty.clone())];
Ok((RelExpr::Proj(Project { input, vars, refs, ty }), next))
Ok(RelExpr::project(input, vars, refs, ty))
}
ast::Project::Exprs(elems) => {
let params = vars;
let mut refs = Vec::new();
let mut vars = Vec::new();
let (mut refs, mut fields) = (Vec::new(), Vec::new());
for ProjectElem(expr, alias) in elems {
match (expr, alias) {
(SqlExpr::Var(var), None) => {
let (i, ty) = params.expect_var(&var.name, None)?;
let ty = expect_col_type(&var.name, ty)?;

refs.push(Ref::Var(i, Type::Alg(ty.clone())));
vars.push((var.name, ty.clone()));
}
(SqlExpr::Var(var), Some(alias)) => {
let (i, ty) = params.expect_var(&var.name, None)?;
let ty = expect_col_type(&var.name, ty)?;

refs.push(Ref::Var(i, Type::Alg(ty.clone())));
vars.push((alias.name, ty.clone()));
}
(SqlExpr::Field(table, field), None) => {
let (i, j, ty) = params.expect_field(&table.name, &field.name, None)?;

refs.push(Ref::Field(i, j, Type::Alg(ty.clone())));
vars.push((field.name, ty.clone()));
}
(SqlExpr::Field(table, field), Some(alias)) => {
let (i, j, ty) = params.expect_field(&table.name, &field.name, None)?;

refs.push(Ref::Field(i, j, Type::Alg(ty.clone())));
vars.push((alias.name, ty.clone()));
}
_ => return Err(TypeError::InvalidProjectExpr),
if let SqlExpr::Var(_) = expr {
return Err(Unsupported::UnqualifiedProjectExpr.into());
}
let SqlExpr::Field(table, field) = expr else {
return Err(Unsupported::ProjectExpr.into());
};
let (i, j, ty) = vars.expect_field(&table.name, &field.name, None)?;
refs.push(Ref::Field(i, j, Type::Alg(ty.clone())));
if let Some(alias) = alias {
fields.push((alias.name, ty.clone()));
} else {
fields.push((field.name, ty.clone()));
}
}
let ty = Type::Row(ProductType::from_iter(vars.iter().map(|(name, t)| {
ProductTypeElement::new_named(t.clone(), name.to_owned().into_boxed_str())
})));
Ok((
RelExpr::Proj(Project {
input: Box::new(input),
vars: params,
refs,
ty,
}),
Vars::from_iter(vars.into_iter().map(|(name, ty)| (name, Type::Alg(ty)))),
))
let ty = Type::Row(ProductType::from_iter(
fields
.into_iter()
.map(|(name, t)| ProductTypeElement::new_named(t, name.into_boxed_str())),
));
Ok(RelExpr::project(input, vars, refs, ty))
}
}
}

/// Type check and lower a [SqlExpr] into a logical [Expr].
fn check_and_lower_expr(params: &Vars, expr: SqlExpr, expected: Option<&Type>) -> TypeResult<Expr> {
fn check_and_lower_expr(vars: &Vars, expr: SqlExpr, expected: Option<&Type>) -> TypingResult<Expr> {
match (expr, expected) {
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(Type::Alg(AlgebraicType::Bool))) => Ok(Expr::bool(v)),
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(t)) => Err(TypeError::unexpected(&Type::BOOL, t)),
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(t)) => Err(ConstraintViolation::eq(&Type::BOOL, t).into()),
(SqlExpr::Lit(SqlLiteral::Str(v)), None | Some(Type::Alg(AlgebraicType::String))) => Ok(Expr::str(v)),
(SqlExpr::Lit(SqlLiteral::Str(_)), Some(t)) => Err(TypeError::unexpected(&Type::STR, t)),
(SqlExpr::Lit(SqlLiteral::Str(_)), Some(t)) => Err(ConstraintViolation::eq(&Type::STR, t).into()),
// Cannot infer the type of numeric literal
(SqlExpr::Lit(SqlLiteral::Num(_)), None) => Err(TypeError::UntypedLiteral),
(SqlExpr::Lit(SqlLiteral::Num(_)), None) => Err(ResolutionError::UntypedLiteral.into()),
(SqlExpr::Lit(SqlLiteral::Num(v)), Some(t)) if t.is_num() => Ok(Expr::num(v, t.clone())),
(SqlExpr::Lit(SqlLiteral::Num(_)), Some(t)) => Err(TypeError::num(t)),
(SqlExpr::Lit(SqlLiteral::Num(_)), Some(t)) => Err(ConstraintViolation::num(t).into()),
// Cannot infer the type of hex literal
(SqlExpr::Lit(SqlLiteral::Hex(_)), None) => Err(TypeError::UntypedLiteral),
(SqlExpr::Lit(SqlLiteral::Hex(_)), None) => Err(ResolutionError::UntypedLiteral.into()),
(SqlExpr::Lit(SqlLiteral::Hex(v)), Some(t)) if t.is_hex() => Ok(Expr::hex(v, t.clone())),
(SqlExpr::Lit(SqlLiteral::Hex(_)), Some(t)) => Err(TypeError::hex(t)),
(SqlExpr::Var(var), expected) => params.expect_var_ref(&var.name, expected),
(SqlExpr::Field(table, field), expected) => params.expect_field_ref(&table.name, &field.name, expected),
(SqlExpr::Lit(SqlLiteral::Hex(_)), Some(t)) => Err(ConstraintViolation::hex(t).into()),
(SqlExpr::Var(var), expected) => vars.expect_var_ref(&var.name, expected),
(SqlExpr::Field(table, field), expected) => vars.expect_field_ref(&table.name, &field.name, expected),
// The operands of and/or must be boolean typed expressions
(SqlExpr::Bin(a, b, op @ BinOp::And | op @ BinOp::Or), None | Some(Type::Alg(AlgebraicType::Bool))) => {
Ok(Expr::Bin(
op,
Box::new(check_and_lower_expr(params, *a, Some(&Type::BOOL))?),
Box::new(check_and_lower_expr(params, *b, Some(&Type::BOOL))?),
Box::new(check_and_lower_expr(vars, *a, Some(&Type::BOOL))?),
Box::new(check_and_lower_expr(vars, *b, Some(&Type::BOOL))?),
))
}
// The operands of binary operators must be the same type
(SqlExpr::Bin(a, b, op), None | Some(Type::Alg(AlgebraicType::Bool))) => match (*a, *b) {
(a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
let a = expect_eq_type(check_and_lower_expr(params, a, None)?)?;
let b = expect_eq_type(check_and_lower_expr(params, b, Some(a.ty()))?)?;
let a = expect_op_type(op, check_and_lower_expr(vars, a, None)?)?;
let b = expect_op_type(op, check_and_lower_expr(vars, b, Some(a.ty()))?)?;
Ok(Expr::Bin(op, Box::new(a), Box::new(b)))
}
},
(SqlExpr::Bin(..), Some(t)) => Err(TypeError::unexpected(&Type::BOOL, t)),
(SqlExpr::Bin(..), Some(t)) => Err(ConstraintViolation::eq(&Type::BOOL, t).into()),
}
}

/// Returns an error if the input type is not a table type [Type::Var]
fn expect_table_type<'a>(name: &str, ty: &'a Type) -> TypeResult<&'a Arc<TableSchema>> {
match ty {
Type::Var(schema) => Ok(schema),
_ => Err(TypeError::table(name)),
}
}

/// Returns an error if the input type is not an algebraic type [Type::Alg]
fn expect_col_type<'a>(name: &str, ty: &'a Type) -> TypeResult<&'a AlgebraicType> {
match ty {
Type::Alg(t) => Ok(t),
_ => Err(TypeError::col(name, ty)),
fn expect_table_type(expr: RelExpr) -> TypingResult<RelExpr> {
match expr.ty() {
Type::Var(_) => Ok(expr),
_ => Err(Unsupported::SubReturnType.into()),
}
}

/// Assert that this type can be compared for equality
fn expect_eq_type(expr: Expr) -> TypeResult<Expr> {
match expr.ty() {
Type::Alg(t) if t.is_bool() || t.is_integer() || t.is_float() || t.is_string() || t.is_bytes() => Ok(expr),
t => Err(TypeError::Eq(t.clone())),
/// Assert that this type is compatible with this operator
fn expect_op_type(op: BinOp, expr: Expr) -> TypingResult<Expr> {
match (op, expr.ty()) {
(BinOp::And | BinOp::Or, Type::Alg(AlgebraicType::Bool)) => Ok(expr),
(BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte, Type::Alg(t)) if t.is_integer() || t.is_float() => Ok(expr),
(BinOp::Eq | BinOp::Ne, Type::Alg(t))
if t.is_bool()
|| t.is_integer()
|| t.is_float()
|| t.is_string()
|| t.is_bytes()
|| t.is_identity()
|| t.is_address() =>
{
Ok(expr)
}
(op, ty) => Err(ConstraintViolation::op(op, ty).into()),
}
}

Expand Down Expand Up @@ -330,12 +286,12 @@ mod tests {
for sql in [
"select * from t",
"select * from t where true",
"select * from t where u32 = 1",
"select * from t where u32 = 1 or str = ''",
"select * from t where t.u32 = 1",
"select * from t where t.u32 = 1 or t.str = ''",
"select * from s as r where r.bytes = 0xABCD",
"select * from (select t.* from t join s)",
"select * from (select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1)",
"select * from (select t.* from t join (select u32 from s) s on t.u32 = s.u32)",
"select * from (select t.* from t join (select s.u32 from s) s on t.u32 = s.u32)",
] {
let result = parse_and_type_sub(sql, &mut tx).inspect_err(|err| {
println!("{}", err.to_string());
Expand Down Expand Up @@ -387,20 +343,28 @@ mod tests {
for sql in [
// Table r does not exist
"select * from r",
// Field u32 is not in scope
"select * from t where u32 = 1",
// Field a does not exist on table t
"select * from t where a = 1",
"select * from t where t.a = 1",
// Field a does not exist on table t
"select * from t as r where r.a = 1",
// Field u32 is not a string
"select * from t where u32 = 'str'",
// Field u32 is not in scope after alias r
"select * from t as r where u32 = 5",
// Field bytes does not exist on table t
"select * from t as r where r.bytes = 0xABCD",
"select * from t where t.u32 = 'str'",
// t is not in scope after alias
"select * from t as r where t.u32 = 5",
// Field u32 is not in scope
"select u32 from t",
// Subscriptions must be typed to a single table
"select t.u32 from t",
// Subscriptions must be typed to a single table
"select * from t join s",
// Product values are not comparable
"select * from t join s on t.row = s.row",
"select * from (select t.* from t join s on t.row = s.row)",
// Subscriptions must be typed to a single table
"select * from (select s.* from t join (select u32 from s) s on t.u32 = s.u32)",
"select * from (select s.* from t join (select s.u32 from s) s on t.u32 = s.u32)",
// Field bytes is no longer in scope
"select * from (select t.* from t join (select s.u32 from s) s on s.bytes = 0xABCD)",
] {
let result = parse_and_type_sub(sql, &mut tx).inspect_err(|err| {
println!("{}", err.to_string());
Expand Down
Loading

0 comments on commit 6f36ef3

Please sign in to comment.