diff --git a/Cargo.lock b/Cargo.lock index b8a3c14fc2..1a8299ab73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4707,6 +4707,7 @@ dependencies = [ "spacetimedb-lib", "spacetimedb-primitives", "spacetimedb-sats", + "spacetimedb-sql-parser", "spacetimedb-testing", "thiserror", "unicode-ident", diff --git a/crates/bindings-csharp/Runtime/Internal/Autogen/RawModuleDefV9.cs b/crates/bindings-csharp/Runtime/Internal/Autogen/RawModuleDefV9.cs index adf37df0b7..7ac231e1d2 100644 --- a/crates/bindings-csharp/Runtime/Internal/Autogen/RawModuleDefV9.cs +++ b/crates/bindings-csharp/Runtime/Internal/Autogen/RawModuleDefV9.cs @@ -25,13 +25,16 @@ public partial class RawModuleDefV9 public System.Collections.Generic.List Types; [DataMember(Name = "misc_exports")] public System.Collections.Generic.List MiscExports; + [DataMember(Name = "row_level_security")] + public System.Collections.Generic.List RowLevelSecurity; public RawModuleDefV9( SpacetimeDB.Internal.Typespace Typespace, System.Collections.Generic.List Tables, System.Collections.Generic.List Reducers, System.Collections.Generic.List Types, - System.Collections.Generic.List MiscExports + System.Collections.Generic.List MiscExports, + System.Collections.Generic.List RowLevelSecurity ) { this.Typespace = Typespace; @@ -39,6 +42,7 @@ public RawModuleDefV9( this.Reducers = Reducers; this.Types = Types; this.MiscExports = MiscExports; + this.RowLevelSecurity = RowLevelSecurity; } public RawModuleDefV9() @@ -48,6 +52,7 @@ public RawModuleDefV9() this.Reducers = new(); this.Types = new(); this.MiscExports = new(); + this.RowLevelSecurity = new(); } } diff --git a/crates/bindings-csharp/Runtime/Internal/Autogen/RawRowLevelSecurityDefV9.cs b/crates/bindings-csharp/Runtime/Internal/Autogen/RawRowLevelSecurityDefV9.cs new file mode 100644 index 0000000000..81c75282e8 --- /dev/null +++ b/crates/bindings-csharp/Runtime/Internal/Autogen/RawRowLevelSecurityDefV9.cs @@ -0,0 +1,34 @@ +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN RUST INSTEAD. +// + +#nullable enable + +using System; +using SpacetimeDB; +using System.Collections.Generic; +using System.Runtime.Serialization; + +namespace SpacetimeDB.Internal +{ + [SpacetimeDB.Type] + [DataContract] + public partial class RawRowLevelSecurityDefV9 + { + [DataMember(Name = "sql")] + public string Sql; + + public RawRowLevelSecurityDefV9( + string Sql + ) + { + this.Sql = Sql; + } + + public RawRowLevelSecurityDefV9() + { + this.Sql = ""; + } + + } +} diff --git a/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs b/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs index c909d95dd3..1098530a9e 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/committed_state.rs @@ -11,9 +11,9 @@ use crate::{ system_tables, StColumnRow, StConstraintData, StConstraintRow, StIndexAlgorithm, StIndexRow, StSequenceRow, StTableFields, StTableRow, SystemTable, ST_CLIENT_ID, ST_CLIENT_IDX, ST_COLUMN_ID, ST_COLUMN_IDX, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_IDX, ST_CONSTRAINT_NAME, ST_INDEX_ID, - ST_INDEX_IDX, ST_INDEX_NAME, ST_MODULE_ID, ST_MODULE_IDX, ST_RESERVED_SEQUENCE_RANGE, ST_SCHEDULED_ID, - ST_SCHEDULED_IDX, ST_SEQUENCE_ID, ST_SEQUENCE_IDX, ST_SEQUENCE_NAME, ST_TABLE_ID, ST_TABLE_IDX, - ST_VAR_ID, ST_VAR_IDX, + ST_INDEX_IDX, ST_INDEX_NAME, ST_MODULE_ID, ST_MODULE_IDX, ST_RESERVED_SEQUENCE_RANGE, + ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_IDX, ST_SCHEDULED_ID, ST_SCHEDULED_IDX, ST_SEQUENCE_ID, + ST_SEQUENCE_IDX, ST_SEQUENCE_NAME, ST_TABLE_ID, ST_TABLE_IDX, ST_VAR_ID, ST_VAR_IDX, }, traits::TxData, }, @@ -225,6 +225,10 @@ impl CommittedState { self.create_table(ST_SCHEDULED_ID, schemas[ST_SCHEDULED_IDX].clone()); + self.create_table(ST_ROW_LEVEL_SECURITY_ID, schemas[ST_ROW_LEVEL_SECURITY_IDX].clone()); + + // IMPORTANT: It is crucial that the `st_sequences` table is created last + // Insert the sequences into `st_sequences` let (st_sequences, blob_store) = self.get_table_and_blob_store_or_create(ST_SEQUENCE_ID, &schemas[ST_SEQUENCE_IDX]); diff --git a/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs b/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs index f9b74fc82e..b158e7b99c 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/datastore.rs @@ -61,7 +61,7 @@ pub struct Locking { /// The state of sequence generation in this database. sequence_state: Arc>, /// The address of this database. - database_address: Address, + pub(crate) database_address: Address, } impl Locking { @@ -379,6 +379,10 @@ impl MutTxDatastore for Locking { tx.table_id_from_name(table_name, self.database_address) } + fn table_id_exists_mut_tx(&self, tx: &Self::MutTx, table_id: &TableId) -> bool { + tx.table_name(*table_id).is_some() + } + fn table_name_from_id_mut_tx<'a>( &'a self, ctx: &'a ExecutionContext, @@ -511,10 +515,6 @@ impl MutTxDatastore for Locking { Ok((gens, row_ref.collapse())) } - fn table_id_exists_mut_tx(&self, tx: &Self::MutTx, table_id: &TableId) -> bool { - tx.table_name(*table_id).is_some() - } - fn metadata_mut_tx(&self, tx: &Self::MutTx) -> Result> { let ctx = ExecutionContext::internal(self.database_address); tx.iter(&ctx, ST_MODULE_ID)?.next().map(metadata_from_row).transpose() @@ -927,9 +927,10 @@ mod tests { use super::*; use crate::db::datastore::system_tables::{ system_tables, StColumnRow, StConstraintData, StConstraintFields, StConstraintRow, StIndexAlgorithm, - StIndexFields, StIndexRow, StScheduledFields, StSequenceFields, StSequenceRow, StTableRow, StVarFields, - StVarValue, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME, ST_INDEX_ID, - ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID, + StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, StSequenceFields, StSequenceRow, + StTableRow, StVarFields, StVarValue, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID, + ST_CONSTRAINT_NAME, ST_INDEX_ID, ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE, + ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_NAME, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID, ST_SEQUENCE_NAME, ST_TABLE_NAME, ST_VAR_ID, ST_VAR_NAME, }; use crate::db::datastore::traits::{IsolationLevel, MutTx}; @@ -944,7 +945,9 @@ mod tests { use spacetimedb_primitives::{col_list, ColId, ScheduleId}; use spacetimedb_sats::{product, AlgebraicType, GroundSpacetimeType}; use spacetimedb_schema::def::{BTreeAlgorithm, ConstraintData, IndexAlgorithm, UniqueConstraintData}; - use spacetimedb_schema::schema::{ColumnSchema, ConstraintSchema, IndexSchema, SequenceSchema}; + use spacetimedb_schema::schema::{ + ColumnSchema, ConstraintSchema, IndexSchema, RowLevelSecuritySchema, SequenceSchema, + }; use spacetimedb_table::table::UniqueConstraintViolation; /// For the first user-created table, sequences in the system tables start @@ -1320,6 +1323,7 @@ mod tests { TableRow { id: ST_CLIENT_ID.into(), name: ST_CLIENT_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: None }, TableRow { id: ST_VAR_ID.into(), name: ST_VAR_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StVarFields::Name.into()) }, TableRow { id: ST_SCHEDULED_ID.into(), name: ST_SCHEDULED_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StScheduledFields::ScheduleId.into()) }, + TableRow { id: ST_ROW_LEVEL_SECURITY_ID.into(), name: ST_ROW_LEVEL_SECURITY_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StRowLevelSecurityFields::Sql.into()) }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_columns()?, map_array([ @@ -1371,6 +1375,9 @@ mod tests { ColRow { table: ST_SCHEDULED_ID.into(), pos: 1, name: "table_id", ty: TableId::get_type() }, ColRow { table: ST_SCHEDULED_ID.into(), pos: 2, name: "reducer_name", ty: AlgebraicType::String }, ColRow { table: ST_SCHEDULED_ID.into(), pos: 3, name: "schedule_name", ty: AlgebraicType::String }, + + ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 0, name: "table_id", ty: TableId::get_type() }, + ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 1, name: "sql", ty: AlgebraicType::String }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_indexes()?, map_array([ @@ -1384,6 +1391,8 @@ mod tests { IndexRow { id: 8, table: ST_VAR_ID.into(), col: col(0), name: "idx_st_var_name_unique", }, IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", }, IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", }, + IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_btree_table_id"}, + IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_sql_unique"}, ])); let start = FIRST_NON_SYSTEM_ID as i128; #[rustfmt::skip] @@ -1412,6 +1421,7 @@ mod tests { ConstraintRow { constraint_id: 8, table_id: ST_VAR_ID.into(), unique_columns: col(0), constraint_name: "ct_st_var_name_unique" }, ConstraintRow { constraint_id: 9, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(0), constraint_name: "ct_st_scheduled_schedule_id_unique" }, ConstraintRow { constraint_id: 10, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(1), constraint_name: "ct_st_scheduled_table_id_unique" }, + ConstraintRow { constraint_id: 11, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(1), constraint_name: "ct_st_row_level_security_sql_unique" }, ])); // Verify we get back the tables correctly with the proper ids... @@ -1823,6 +1833,8 @@ mod tests { IndexRow { id: 8, table: ST_VAR_ID.into(), col: col(0), name: "idx_st_var_name_unique", }, IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", }, IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", }, + IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_btree_table_id"}, + IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_sql_unique"}, IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx", }, IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx", }, IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "age_idx", }, @@ -1876,6 +1888,8 @@ mod tests { IndexRow { id: 8, table: ST_VAR_ID.into(), col: col(0), name: "idx_st_var_name_unique", }, IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", }, IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", }, + IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_btree_table_id"}, + IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_sql_unique"}, IndexRow { id: seq_start , table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx" }, IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx" }, IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "age_idx" }, @@ -1930,6 +1944,8 @@ mod tests { IndexRow { id: 8, table: ST_VAR_ID.into(), col: col(0), name: "idx_st_var_name_unique", }, IndexRow { id: 9, table: ST_SCHEDULED_ID.into(), col: col(0), name: "idx_st_scheduled_schedule_id_unique", }, IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "idx_st_scheduled_table_id_unique", }, + IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "idx_st_row_level_security_btree_table_id"}, + IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "idx_st_row_level_security_sql_unique"}, IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "id_idx" }, IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "name_idx" }, ].map(Into::into)); @@ -2019,6 +2035,32 @@ mod tests { Ok(()) } + #[test] + fn test_row_level_security() -> ResultTest<()> { + let (_, mut tx, table_id) = setup_table()?; + + let rls = RowLevelSecuritySchema { + sql: "SELECT * FROM bar".into(), + table_id, + }; + let ctx = ExecutionContext::default(); + tx.create_row_level_security(&ctx, rls.clone())?; + + let result = tx.row_level_security_for_table_id(&ctx, table_id)?; + assert_eq!( + result, + vec![RowLevelSecuritySchema { + sql: "SELECT * FROM bar".into(), + table_id, + }] + ); + + tx.drop_row_level_security(&ctx, rls.sql)?; + assert_eq!(tx.row_level_security_for_table_id(&ctx, table_id)?, []); + + Ok(()) + } + // TODO: Add the following tests // - Create index with unique constraint and immediately insert a row that violates the constraint before committing. // - Create a tx that inserts 2000 rows with an auto_inc column diff --git a/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs b/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs index be9c55aa23..dc9dc5e4ef 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/mut_tx.rs @@ -7,6 +7,7 @@ use super::{ tx_state::{DeleteTable, IndexIdMap, TxState}, SharedMutexGuard, SharedWriteGuard, }; +use crate::db::datastore::system_tables::{StRowLevelSecurityFields, StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID}; use crate::db::datastore::{ system_tables::{ StColumnFields, StColumnRow, StConstraintFields, StConstraintRow, StFields as _, StIndexFields, StIndexRow, @@ -22,6 +23,7 @@ use crate::{ use core::ops::RangeBounds; use core::{iter, ops::Bound}; use smallvec::SmallVec; +use spacetimedb_lib::db::raw_def::v9::RawSql; use spacetimedb_lib::{ address::Address, bsatn::Deserializer, @@ -36,7 +38,7 @@ use spacetimedb_sats::{ }; use spacetimedb_schema::{ def::{BTreeAlgorithm, IndexAlgorithm}, - schema::{ConstraintSchema, IndexSchema, SequenceSchema, TableSchema}, + schema::{ConstraintSchema, IndexSchema, RowLevelSecuritySchema, SequenceSchema, TableSchema}, }; use spacetimedb_table::{ blob_store::{BlobStore, HashMapBlobStore}, @@ -946,6 +948,92 @@ impl MutTxId { }) } + /// Create a row level security policy. + /// + /// Requires: + /// - `row_level_security_schema.table_id != TableId::SENTINEL` + /// - `row_level_security_schema.sql` must be unique. + /// + /// Ensures: + /// + /// - The row level security policy metadata is inserted into the system tables (and other data structures reflecting them). + /// - The returned `sql` is unique. + pub fn create_row_level_security( + &mut self, + ctx: &ExecutionContext, + row_level_security_schema: RowLevelSecuritySchema, + ) -> Result { + if row_level_security_schema.table_id == TableId::SENTINEL { + return Err(anyhow::anyhow!( + "`table_id` must not be `TableId::SENTINEL` in `{:#?}`", + row_level_security_schema + ) + .into()); + } + + log::trace!( + "ROW LEVEL SECURITY CREATING for table: {}", + row_level_security_schema.table_id + ); + + // Insert the row into st_row_level_security + // NOTE: Because st_row_level_security has a unique index on sql, this will + // fail if already exists. + let row = StRowLevelSecurityRow { + table_id: row_level_security_schema.table_id, + sql: row_level_security_schema.sql, + }; + + let row = self.insert(ST_ROW_LEVEL_SECURITY_ID, &mut ProductValue::from(row), ctx.database())?; + let row_level_security_sql = row.1.collapse().read_col(StRowLevelSecurityFields::Sql)?; + let existed = matches!(row.1, RowRefInsertion::Existed(_)); + + // Add the row level security to the transaction's insert table. + self.get_or_create_insert_table_mut(row_level_security_schema.table_id)?; + + if existed { + log::trace!("ROW LEVEL SECURITY ALREADY EXISTS: {row_level_security_sql}"); + } else { + log::trace!("ROW LEVEL SECURITY CREATED: {row_level_security_sql}"); + } + + Ok(row_level_security_sql) + } + + pub fn row_level_security_for_table_id( + &self, + ctx: &ExecutionContext, + table_id: TableId, + ) -> Result> { + Ok(self + .iter_by_col_eq( + ctx, + ST_ROW_LEVEL_SECURITY_ID, + StRowLevelSecurityFields::TableId, + &table_id.into(), + )? + .map(|row| { + let row = StRowLevelSecurityRow::try_from(row).unwrap(); + row.into() + }) + .collect()) + } + + pub fn drop_row_level_security(&mut self, ctx: &ExecutionContext, sql: RawSql) -> Result<()> { + let st_rls_ref = self + .iter_by_col_eq( + ctx, + ST_ROW_LEVEL_SECURITY_ID, + StRowLevelSecurityFields::Sql, + &sql.clone().into(), + )? + .next() + .ok_or_else(|| TableError::RawSqlNotFound(SystemTable::st_row_level_security, sql))?; + self.delete(ST_ROW_LEVEL_SECURITY_ID, st_rls_ref.pointer())?; + + Ok(()) + } + // TODO(perf, deep-integration): // When all of [`Table::read_row`], [`RowRef::new`], [`CommittedState::get`] // and [`TxState::get`] become unsafe, diff --git a/crates/core/src/db/datastore/system_tables.rs b/crates/core/src/db/datastore/system_tables.rs index bb2f03b6ad..0318870136 100644 --- a/crates/core/src/db/datastore/system_tables.rs +++ b/crates/core/src/db/datastore/system_tables.rs @@ -16,6 +16,7 @@ use crate::error::DBError; use crate::execution_context::ExecutionContext; use derive_more::From; use spacetimedb_lib::db::auth::{StAccess, StTableType}; +use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawSql}; use spacetimedb_lib::db::raw_def::*; use spacetimedb_lib::de::{Deserialize, DeserializeOwned, Error}; use spacetimedb_lib::ser::Serialize; @@ -30,7 +31,8 @@ use spacetimedb_sats::{ }; use spacetimedb_schema::def::{BTreeAlgorithm, ConstraintData, IndexAlgorithm, ModuleDef, UniqueConstraintData}; use spacetimedb_schema::schema::{ - ColumnSchema, ConstraintSchema, IndexSchema, ScheduleSchema, Schema, SequenceSchema, TableSchema, + ColumnSchema, ConstraintSchema, IndexSchema, RowLevelSecuritySchema, ScheduleSchema, Schema, SequenceSchema, + TableSchema, }; use spacetimedb_table::table::RowRef; use spacetimedb_vm::errors::{ErrorType, ErrorVm}; @@ -63,6 +65,8 @@ pub(crate) const ST_VAR_ID: TableId = TableId(8); /// The static ID of the table that defines scheduled tables pub(crate) const ST_SCHEDULED_ID: TableId = TableId(9); +/// The static ID of the table that defines the row level security (RLS) policies +pub(crate) const ST_ROW_LEVEL_SECURITY_ID: TableId = TableId(10); pub(crate) const ST_TABLE_NAME: &str = "st_table"; pub(crate) const ST_COLUMN_NAME: &str = "st_column"; pub(crate) const ST_SEQUENCE_NAME: &str = "st_sequence"; @@ -72,7 +76,7 @@ pub(crate) const ST_MODULE_NAME: &str = "st_module"; pub(crate) const ST_CLIENT_NAME: &str = "st_client"; pub(crate) const ST_SCHEDULED_NAME: &str = "st_scheduled"; pub(crate) const ST_VAR_NAME: &str = "st_var"; - +pub(crate) const ST_ROW_LEVEL_SECURITY_NAME: &str = "st_row_level_security"; /// Reserved range of sequence values used for system tables. /// /// Ids for user-created tables will start at `ST_RESERVED_SEQUENCE_RANGE + 1`. @@ -97,10 +101,12 @@ pub enum SystemTable { st_sequence, st_index, st_constraint, + st_row_level_security, } -pub(crate) fn system_tables() -> [TableSchema; 9] { +pub(crate) fn system_tables() -> [TableSchema; 10] { [ + // The order should match the `id` of the system table, that start with [ST_TABLE_IDX]. st_table_schema(), st_column_schema(), st_index_schema(), @@ -109,6 +115,7 @@ pub(crate) fn system_tables() -> [TableSchema; 9] { st_client_schema(), st_var_schema(), st_scheduled_schema(), + st_row_level_security_schema(), // Is important this is always last, so the starting sequence for each // system table is correct. st_sequence_schema(), @@ -148,7 +155,9 @@ pub(crate) const ST_MODULE_IDX: usize = 4; pub(crate) const ST_CLIENT_IDX: usize = 5; pub(crate) const ST_VAR_IDX: usize = 6; pub(crate) const ST_SCHEDULED_IDX: usize = 7; -pub(crate) const ST_SEQUENCE_IDX: usize = 8; +pub(crate) const ST_ROW_LEVEL_SECURITY_IDX: usize = 8; +// Must be the last index in the array. +pub(crate) const ST_SEQUENCE_IDX: usize = 9; macro_rules! st_fields_enum { ($(#[$attr:meta])* enum $ty_name:ident { $($name:expr, $var:ident = $discr:expr,)* }) => { @@ -228,6 +237,11 @@ st_fields_enum!(enum StConstraintFields { "constraint_data", ConstraintData = 3, }); // WARNING: For a stable schema, don't change the field names and discriminants. +st_fields_enum!(enum StRowLevelSecurityFields { + "table_id", TableId = 0, + "sql", Sql = 1, +}); +// WARNING: For a stable schema, don't change the field names and discriminants. st_fields_enum!(enum StModuleFields { "database_address", DatabaseAddress = 0, "owner_identity", OwnerIdentity = 1, @@ -311,6 +325,23 @@ fn system_module_def() -> ModuleDef { .with_auto_inc_primary_key(StConstraintFields::ConstraintId); // TODO(1.0): unique constraint on name? + let st_row_level_security_type = builder.add_type::(); + builder + .build_table( + ST_ROW_LEVEL_SECURITY_NAME, + *st_row_level_security_type.as_ref().expect("should be ref"), + ) + .with_type(TableType::System) + .with_primary_key(StRowLevelSecurityFields::Sql) + .with_unique_constraint(StRowLevelSecurityFields::Sql, None) + .with_index( + RawIndexAlgorithm::BTree { + columns: StRowLevelSecurityFields::TableId.into(), + }, + "accessor_name_doesnt_matter", + None, + ); + let st_module_type = builder.add_type::(); builder .build_table(ST_MODULE_NAME, *st_module_type.as_ref().expect("should be ref")) @@ -348,6 +379,7 @@ fn system_module_def() -> ModuleDef { validate_system_table::(&result, ST_INDEX_NAME); validate_system_table::(&result, ST_SEQUENCE_NAME); validate_system_table::(&result, ST_CONSTRAINT_NAME); + validate_system_table::(&result, ST_ROW_LEVEL_SECURITY_NAME); validate_system_table::(&result, ST_MODULE_NAME); validate_system_table::(&result, ST_CLIENT_NAME); validate_system_table::(&result, ST_VAR_NAME); @@ -400,6 +432,10 @@ fn st_constraint_schema() -> TableSchema { st_schema(ST_CONSTRAINT_NAME, ST_CONSTRAINT_ID) } +fn st_row_level_security_schema() -> TableSchema { + st_schema(ST_ROW_LEVEL_SECURITY_NAME, ST_ROW_LEVEL_SECURITY_ID) +} + pub(crate) fn st_module_schema() -> TableSchema { st_schema(ST_MODULE_NAME, ST_MODULE_ID) } @@ -429,6 +465,7 @@ pub(crate) fn system_table_schema(table_id: TableId) -> Option { ST_SEQUENCE_ID => Some(st_sequence_schema()), ST_INDEX_ID => Some(st_index_schema()), ST_CONSTRAINT_ID => Some(st_constraint_schema()), + ST_ROW_LEVEL_SECURITY_ID => Some(st_row_level_security_schema()), ST_MODULE_ID => Some(st_module_schema()), ST_CLIENT_ID => Some(st_client_schema()), ST_VAR_ID => Some(st_var_schema()), @@ -715,6 +752,39 @@ impl From for ConstraintSchema { } } +/// System Table [ST_ROW_LEVEL_SECURITY_NAME] +/// +/// | table_id | sql | +/// |----------|--------------| +/// | 1 | "SELECT ..." | +#[derive(Debug, Clone, PartialEq, Eq, SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct StRowLevelSecurityRow { + pub(crate) table_id: TableId, + pub(crate) sql: RawSql, +} + +impl TryFrom> for StRowLevelSecurityRow { + type Error = DBError; + fn try_from(row: RowRef<'_>) -> Result { + read_via_bsatn(row) + } +} + +impl From for ProductValue { + fn from(x: StRowLevelSecurityRow) -> Self { + to_product_value(&x) + } +} + +impl From for RowLevelSecuritySchema { + fn from(x: StRowLevelSecurityRow) -> Self { + Self { + table_id: x.table_id, + sql: x.sql, + } + } +} /// Indicates the kind of module the `program_bytes` of a [`StModuleRow`] /// describes. /// diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index 874f8af231..ef493c0aa8 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -27,12 +27,12 @@ use spacetimedb_commitlog as commitlog; use spacetimedb_durability::{self as durability, Durability, TxOffset}; use spacetimedb_lib::address::Address; use spacetimedb_lib::db::auth::StAccess; -use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawModuleDefV9Builder}; +use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawModuleDefV9Builder, RawSql}; use spacetimedb_lib::Identity; use spacetimedb_primitives::*; use spacetimedb_sats::{AlgebraicType, AlgebraicValue, ProductType, ProductValue}; use spacetimedb_schema::def::{ModuleDef, TableDef}; -use spacetimedb_schema::schema::{IndexSchema, Schema, SequenceSchema, TableSchema}; +use spacetimedb_schema::schema::{IndexSchema, RowLevelSecuritySchema, Schema, SequenceSchema, TableSchema}; use spacetimedb_snapshot::{SnapshotError, SnapshotRepository}; use spacetimedb_table::indexes::RowPointer; use spacetimedb_table::table::RowRef; @@ -1008,6 +1008,29 @@ impl RelationalDB { self.inner.drop_index_mut_tx(tx, index_id) } + pub fn create_row_level_security( + &self, + tx: &mut MutTx, + row_level_security_schema: RowLevelSecuritySchema, + ) -> Result { + let ctx = &ExecutionContext::internal(self.inner.database_address); + tx.create_row_level_security(ctx, row_level_security_schema) + } + + pub fn drop_row_level_security(&self, tx: &mut MutTx, sql: RawSql) -> Result<(), DBError> { + let ctx = &ExecutionContext::internal(self.inner.database_address); + tx.drop_row_level_security(ctx, sql) + } + + pub fn row_level_security_for_table_id_mut_tx( + &self, + tx: &mut MutTx, + table_id: TableId, + ) -> Result, DBError> { + let ctx = &ExecutionContext::internal(self.inner.database_address); + tx.row_level_security_for_table_id(ctx, table_id) + } + /// Returns an iterator, /// yielding every row in the table identified by `table_id`. pub fn iter_mut<'a>( @@ -1506,6 +1529,7 @@ mod tests { use spacetimedb_sats::bsatn; use spacetimedb_sats::buffer::BufReader; use spacetimedb_sats::product; + use spacetimedb_schema::schema::RowLevelSecuritySchema; use spacetimedb_table::read_column::ReadColumn; use spacetimedb_table::table::RowRef; @@ -1879,6 +1903,38 @@ mod tests { Ok(()) } + // Because we don't create `rls` when first creating the database, check we pass the bootstrap + #[test] + fn test_row_level_reopen() -> ResultTest<()> { + let stdb = TestDB::durable()?; + let mut tx = stdb.begin_mut_tx(IsolationLevel::Serializable); + let ctx = ExecutionContext::default(); + + let schema = my_table(AlgebraicType::I64); + let table_id = stdb.create_table(&mut tx, schema)?; + + let rls = RowLevelSecuritySchema { + sql: "SELECT * FROM bar".into(), + table_id, + }; + + tx.create_row_level_security(&ctx, rls)?; + stdb.commit_tx(&ctx, tx)?; + + let stdb = stdb.reopen()?; + let tx = stdb.begin_mut_tx(IsolationLevel::Serializable); + + assert_eq!( + tx.row_level_security_for_table_id(&ctx, table_id)?, + vec![RowLevelSecuritySchema { + sql: "SELECT * FROM bar".into(), + table_id, + }] + ); + + Ok(()) + } + #[test] fn test_unique() -> ResultTest<()> { let stdb = TestDB::durable()?; diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index b65b0e4b23..8721bed5c8 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -17,6 +17,7 @@ use crate::db::datastore::system_tables::SystemTable; use crate::host::scheduler::ScheduleError; use spacetimedb_lib::buffer::DecodeError; use spacetimedb_lib::db::error::{LibError, RelationError, SchemaErrors}; +use spacetimedb_lib::db::raw_def::v9::RawSql; use spacetimedb_lib::db::raw_def::RawIndexDefV8; use spacetimedb_lib::relation::FieldName; use spacetimedb_lib::ProductValue; @@ -37,6 +38,8 @@ pub enum TableError { NotFound(String), #[error("Table with ID `{1}` not found in `{0}`.")] IdNotFound(SystemTable, u32), + #[error("Sql `{1}` not found in `{0}`.")] + RawSqlNotFound(SystemTable, RawSql), #[error("Table with ID `{0}` not found in `TxState`.")] IdNotFoundState(TableId), #[error("Column `{0}.{1}` is missing a name")] diff --git a/crates/lib/src/db/raw_def/v9.rs b/crates/lib/src/db/raw_def/v9.rs index 39b24f3d45..88251ae56c 100644 --- a/crates/lib/src/db/raw_def/v9.rs +++ b/crates/lib/src/db/raw_def/v9.rs @@ -25,6 +25,9 @@ use crate::db::auth::StTableType; /// A not-yet-validated identifier. pub type RawIdentifier = Box; +/// A not-yet-validated `sql`. +pub type RawSql = Box; + /// A possibly-invalid raw module definition. /// /// ABI Version 9. @@ -79,6 +82,11 @@ pub struct RawModuleDefV9 { /// Miscellaneous additional module exports. pub misc_exports: Vec, + + /// Low level security definitions. + /// + /// Each definition must have a unique name. + pub row_level_security: Vec, } /// The definition of a database table. @@ -330,6 +338,15 @@ pub struct RawUniqueConstraintDataV9 { pub columns: ColList, } +/// Data for the `RLS` policy on a table. +#[derive(Debug, Clone, SpacetimeType)] +#[sats(crate = crate)] +#[cfg_attr(feature = "test", derive(PartialEq, Eq, PartialOrd, Ord))] +pub struct RawRowLevelSecurityDefV9 { + /// The `sql` expression to use for row-level security. + pub sql: RawSql, +} + /// A miscellaneous module export. #[derive(Debug, Clone, SpacetimeType)] #[sats(crate = crate)] @@ -531,6 +548,17 @@ impl RawModuleDefV9Builder { }); } + /// Add a row-level security policy to the module. + /// + /// The `sql` expression should be a valid SQL expression that will be used to filter rows. + /// + /// **NOTE**: The `sql` expression must be unique within the module. + pub fn add_row_level_security(&mut self, sql: &str) { + self.module + .row_level_security + .push(RawRowLevelSecurityDefV9 { sql: sql.into() }); + } + /// Get the typespace of the module. pub fn typespace(&self) -> &Typespace { &self.module.typespace diff --git a/crates/schema/Cargo.toml b/crates/schema/Cargo.toml index c64aaae60d..f32a7efee2 100644 --- a/crates/schema/Cargo.toml +++ b/crates/schema/Cargo.toml @@ -14,6 +14,7 @@ spacetimedb-lib.workspace = true spacetimedb-primitives.workspace = true spacetimedb-sats.workspace = true spacetimedb-data-structures.workspace = true +spacetimedb-sql-parser.workspace = true anyhow.workspace = true itertools.workspace = true diff --git a/crates/schema/src/def.rs b/crates/schema/src/def.rs index f7cc5d8c33..d68760c948 100644 --- a/crates/schema/src/def.rs +++ b/crates/schema/src/def.rs @@ -30,8 +30,8 @@ use spacetimedb_data_structures::map::HashMap; use spacetimedb_lib::db::raw_def; use spacetimedb_lib::db::raw_def::v9::{ Lifecycle, RawConstraintDataV9, RawConstraintDefV9, RawIdentifier, RawIndexAlgorithm, RawIndexDefV9, - RawModuleDefV9, RawReducerDefV9, RawScheduleDefV9, RawScopedTypeNameV9, RawSequenceDefV9, RawTableDefV9, - RawTypeDefV9, RawUniqueConstraintDataV9, TableAccess, TableType, + RawModuleDefV9, RawReducerDefV9, RawRowLevelSecurityDefV9, RawScheduleDefV9, RawScopedTypeNameV9, RawSequenceDefV9, + RawSql, RawTableDefV9, RawTypeDefV9, RawUniqueConstraintDataV9, TableAccess, TableType, }; use spacetimedb_lib::{ProductType, RawModuleDef}; use spacetimedb_primitives::{ColId, ColList, ColSet, TableId}; @@ -103,6 +103,11 @@ pub struct ModuleDef { /// A map from type defs to their names. refmap: HashMap, + + /// The row-level security policies. + /// + /// **Note**: Are only validated syntax-wise. + row_level_security_raw: HashMap, } impl ModuleDef { @@ -141,6 +146,11 @@ impl ModuleDef { self.types.values() } + /// The row-level security policies of the module definition. + pub fn row_level_security(&self) -> impl Iterator { + self.row_level_security_raw.values() + } + /// The `Typespace` used by the module. /// /// `AlgebraicTypeRef`s in the table, reducer, and type alias declarations refer to this typespace. @@ -340,6 +350,7 @@ impl From for RawModuleDefV9 { stored_in_table_def: _, typespace_for_generate: _, refmap: _, + row_level_security_raw, } = val; RawModuleDefV9 { @@ -348,6 +359,7 @@ impl From for RawModuleDefV9 { types: to_raw(types, |type_: &RawTypeDefV9| &type_.name), misc_exports: vec![], typespace, + row_level_security: row_level_security_raw.into_iter().map(|(_, def)| def).collect(), } } } @@ -688,6 +700,19 @@ impl From for ConstraintData { } } +/// Data for the `RLS` policy on a table. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct RowLevelSecurityDef { + /// The `sql` expression to use for row-level security. + pub sql: RawSql, +} + +impl From for RawRowLevelSecurityDefV9 { + fn from(val: RowLevelSecurityDef) -> Self { + RawRowLevelSecurityDefV9 { sql: val.sql } + } +} + /// Marks a table as a timer table for a scheduled reducer. #[derive(Debug, Clone, Eq, PartialEq)] #[non_exhaustive] @@ -937,6 +962,18 @@ impl ModuleDefLookup for ConstraintDef { } } +impl ModuleDefLookup for RawRowLevelSecurityDefV9 { + type Key<'a> = &'a RawSql; + + fn key(&self) -> Self::Key<'_> { + &self.sql + } + + fn lookup<'a>(module_def: &'a ModuleDef, key: Self::Key<'_>) -> Option<&'a Self> { + module_def.row_level_security_raw.get(key) + } +} + impl ModuleDefLookup for ScheduleDef { type Key<'a> = &'a Identifier; diff --git a/crates/schema/src/def/validate/v8.rs b/crates/schema/src/def/validate/v8.rs index c124e00fca..29e6a8fb9a 100644 --- a/crates/schema/src/def/validate/v8.rs +++ b/crates/schema/src/def/validate/v8.rs @@ -58,6 +58,7 @@ fn upgrade_module(def: RawModuleDefV8, extra_errors: &mut Vec) reducers, types, misc_exports: Default::default(), + row_level_security: vec![], // v8 doesn't have row-level security } } diff --git a/crates/schema/src/def/validate/v9.rs b/crates/schema/src/def/validate/v9.rs index a48f0077b6..cd9d90ac87 100644 --- a/crates/schema/src/def/validate/v9.rs +++ b/crates/schema/src/def/validate/v9.rs @@ -16,6 +16,7 @@ pub fn validate(def: RawModuleDefV9) -> Result { reducers, types, misc_exports, + row_level_security, } = def; let known_type_definitions = types.iter().map(|def| def.ty); @@ -56,6 +57,11 @@ pub fn validate(def: RawModuleDefV9) -> Result { }) .collect_all_errors(); + let row_level_security_raw = row_level_security + .into_iter() + .map(|rls| (rls.sql.clone(), rls)) + .collect(); + let mut refmap = HashMap::default(); let types = types .into_iter() @@ -99,6 +105,7 @@ pub fn validate(def: RawModuleDefV9) -> Result { typespace_for_generate, stored_in_table_def, refmap, + row_level_security_raw, }; result.generate_indexes(); diff --git a/crates/schema/src/error.rs b/crates/schema/src/error.rs index 0325106cfe..8fcb7349e3 100644 --- a/crates/schema/src/error.rs +++ b/crates/schema/src/error.rs @@ -110,6 +110,8 @@ pub enum ValidationError { }, #[error("Table name is reserved for system use: {table}")] TableNameReserved { table: Identifier }, + #[error("Row-level security invalid: `{error}`, query: `{sql}")] + InvalidRowLevelQuery { sql: String, error: String }, } /// A wrapper around an `AlgebraicType` that implements `fmt::Display`. diff --git a/crates/schema/src/schema.rs b/crates/schema/src/schema.rs index 25a1a2ab55..816540add7 100644 --- a/crates/schema/src/schema.rs +++ b/crates/schema/src/schema.rs @@ -9,6 +9,7 @@ use itertools::Itertools; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::db::error::{DefType, SchemaError}; +use spacetimedb_lib::db::raw_def::v9::RawSql; use spacetimedb_lib::db::raw_def::{generate_cols_name, RawConstraintDefV8}; use spacetimedb_lib::relation::{combine_constraints, Column, DbTable, FieldName, Header}; use spacetimedb_lib::{AlgebraicType, ProductType, ProductTypeElement}; @@ -1009,3 +1010,10 @@ impl Schema for ConstraintSchema { Ok(()) } } + +/// A struct representing the schema of a row-level security policy. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RowLevelSecuritySchema { + pub table_id: TableId, + pub sql: RawSql, +}