From ad6e4ae04ae1faaf148b219dd60d15edd7f81159 Mon Sep 17 00:00:00 2001 From: James Gilles Date: Wed, 16 Aug 2023 14:01:46 -0400 Subject: [PATCH] One-off query support in SpacetimeDB core. --- .../protobuf/client_api.proto | 34 +++++++ crates/core/src/client/client_connection.rs | 21 +++- crates/core/src/client/message_handlers.rs | 97 +++++++++++++++---- crates/core/src/client/messages.rs | 56 ++++++++++- crates/core/src/db/relational_db.rs | 18 ++++ crates/core/src/host/module_host.rs | 31 ++++++ .../src/host/wasm_common/module_host_actor.rs | 28 ++++++ crates/core/src/json/client_api.rs | 14 +++ .../subscription/module_subscription_actor.rs | 4 +- crates/core/src/subscription/query.rs | 5 +- 10 files changed, 285 insertions(+), 23 deletions(-) diff --git a/crates/client-api-messages/protobuf/client_api.proto b/crates/client-api-messages/protobuf/client_api.proto index 9f4d1bbd1ff..5014a10575d 100644 --- a/crates/client-api-messages/protobuf/client_api.proto +++ b/crates/client-api-messages/protobuf/client_api.proto @@ -21,6 +21,10 @@ message Message { IdentityToken identityToken = 5; // client -> database, register SQL queries on which to receive updates. Subscribe subscribe = 6; + // client -> database, send a one-off SQL query without establishing a subscription. + OneOffQuery oneOffQuery = 7; + // database -> client, return results to a one-off SQL query. + OneOffQueryResponse oneOffQueryResponse = 8; } } @@ -191,3 +195,33 @@ message TransactionUpdate { Event event = 1; SubscriptionUpdate subscriptionUpdate = 2; } + +/// A one-off query submission. +/// +/// Query should be a "SELECT * FROM Table WHERE ...". Other types of queries will be rejected. +/// Multiple such semicolon-delimited queries are allowed. +/// +/// One-off queries are identified by a client-generated messageId. +/// To avoid data leaks, the server will NOT cache responses to messages based on messageId! +/// It also will not check for duplicate IDs. 
They are just a way to match responses to messages. +message OneOffQuery { + bytes messageId = 1; + string queryString = 2; +} + +/// A one-off query response. +/// Will contain either one error or multiple response rows. +/// At most one of these messages will be sent in reply to any query. +/// +/// The messageId will be identical to the one sent in the original query. +message OneOffQueryResponse { + bytes messageId = 1; + string error = 2; + repeated OneOffTable tables = 3; +} + +/// A table included as part of a one-off query. +message OneOffTable { + string tableName = 2; + repeated bytes row = 4; +} diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 25d28343d0f..2b4bcdc434f 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -6,7 +6,7 @@ use crate::worker_metrics::{CONNECTED_CLIENTS, WEBSOCKET_SENT, WEBSOCKET_SENT_MS use futures::prelude::*; use tokio::sync::mpsc; -use super::messages::ServerMessage; +use super::messages::{OneOffQueryResponseMessage, ServerMessage}; use super::{message_handlers, ClientActorId, MessageHandleError}; #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] @@ -168,4 +168,23 @@ impl ClientConnection { pub fn subscribe(&self, subscription: Subscribe) -> Result<(), NoSuchModule> { self.module.subscription().add_subscriber(self.sender(), subscription) } + + pub async fn one_off_query(&self, query: &str, message_id: &[u8]) -> Result<(), anyhow::Error> { + let result = self.module.one_off_query(self.id.identity, query.to_owned()).await; + let message_id = message_id.to_owned(); + let response = match result { + Ok(results) => OneOffQueryResponseMessage { + message_id, + error: None, + results, + }, + Err(err) => OneOffQueryResponseMessage { + message_id, + error: Some(format!("{}", err)), + results: Vec::new(), + }, + }; + self.send_message(response).await?; + Ok(()) + } } diff --git 
a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index 6d91f3d6183..84ee825c769 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -5,6 +5,7 @@ use crate::host::{EnergyDiff, ReducerArgs, Timestamp}; use crate::identity::Identity; use crate::protobuf::client_api::{message, FunctionCall, Message, Subscribe}; use crate::worker_metrics::{WEBSOCKET_REQUESTS, WEBSOCKET_REQUEST_MSG_SIZE}; +use base64::Engine; use bytes::Bytes; use bytestring::ByteString; use prost::Message as _; @@ -20,6 +21,8 @@ pub enum MessageHandleError { InvalidMessage, #[error(transparent)] TextDecode(#[from] serde_json::Error), + #[error(transparent)] + Base64Decode(#[from] base64::DecodeError), #[error(transparent)] Execution(#[from] MessageExecutionError), @@ -53,6 +56,10 @@ async fn handle_binary(client: &ClientConnection, message_buf: Vec) -> Resul DecodedMessage::Call { reducer, args } } Some(message::Type::Subscribe(subscription)) => DecodedMessage::Subscribe(subscription), + Some(message::Type::OneOffQuery(ref oneoff)) => DecodedMessage::OneOffQuery { + query_string: &oneoff.query_string[..], + message_id: &oneoff.message_id[..], + }, _ => return Err(MessageHandleError::InvalidMessage), }; @@ -61,27 +68,50 @@ async fn handle_binary(client: &ClientConnection, message_buf: Vec) -> Resul Ok(()) } -async fn handle_text(client: &ClientConnection, message: String) -> Result<(), MessageHandleError> { - #[derive(serde::Deserialize)] - enum Message<'a> { - #[serde(rename = "call")] - Call { - #[serde(borrow, rename = "fn")] - func: std::borrow::Cow<'a, str>, - args: &'a serde_json::value::RawValue, - }, - #[serde(rename = "subscribe")] - Subscribe { query_strings: Vec }, - } +#[derive(serde::Deserialize)] +enum RawJsonMessage<'a> { + #[serde(rename = "call")] + Call { + #[serde(borrow, rename = "fn")] + func: std::borrow::Cow<'a, str>, + args: &'a serde_json::value::RawValue, + }, + #[serde(rename = 
"subscribe")] + Subscribe { query_strings: Vec }, + #[serde(rename = "one_off_query")] + OneOffQuery { + #[serde(borrow)] + query_string: std::borrow::Cow<'a, str>, + + /// A base64-encoded string of bytes. + #[serde(borrow)] + message_id: std::borrow::Cow<'a, str>, + }, +} +async fn handle_text(client: &ClientConnection, message: String) -> Result<(), MessageHandleError> { let message = ByteString::from(message); - let msg = serde_json::from_str::(&message)?; + let msg = serde_json::from_str::(&message)?; + let mut message_id_ = Vec::new(); let msg = match msg { - Message::Call { ref func, args } => { + RawJsonMessage::Call { ref func, args } => { let args = ReducerArgs::Json(message.slice_ref(args.get())); DecodedMessage::Call { reducer: func, args } } - Message::Subscribe { query_strings } => DecodedMessage::Subscribe(Subscribe { query_strings }), + RawJsonMessage::Subscribe { query_strings } => DecodedMessage::Subscribe(Subscribe { query_strings }), + RawJsonMessage::OneOffQuery { + query_string: ref query, + message_id, + } => { + let _ = std::mem::replace( + &mut message_id_, + base64::engine::general_purpose::STANDARD.decode(&message_id[..])?, + ); + DecodedMessage::OneOffQuery { + query_string: &query[..], + message_id: &message_id_[..], + } + } }; msg.handle(client).await?; @@ -90,8 +120,15 @@ async fn handle_text(client: &ClientConnection, message: String) -> Result<(), M } enum DecodedMessage<'a> { - Call { reducer: &'a str, args: ReducerArgs }, + Call { + reducer: &'a str, + args: ReducerArgs, + }, Subscribe(Subscribe), + OneOffQuery { + query_string: &'a str, + message_id: &'a [u8], + }, } impl DecodedMessage<'_> { @@ -102,6 +139,10 @@ impl DecodedMessage<'_> { res.map(drop).map_err(|e| (Some(reducer), e.into())) } DecodedMessage::Subscribe(subscription) => client.subscribe(subscription).map_err(|e| (None, e.into())), + DecodedMessage::OneOffQuery { + query_string: query, + message_id, + } => client.one_off_query(query, message_id).await.map_err(|err| 
(None, err)), }; res.map_err(|(reducer, err)| MessageExecutionError { reducer: reducer.map(str::to_owned), @@ -111,7 +152,7 @@ impl DecodedMessage<'_> { } } -/// An error that arises from +/// An error that arises from executing a message. #[derive(thiserror::Error, Debug)] #[error("error executing message (reducer: {reducer:?}) (err: {err:?})")] pub struct MessageExecutionError { @@ -154,3 +195,25 @@ impl ServerMessage for MessageExecutionError { .serialize_binary() } } + +#[cfg(test)] +mod tests { + use super::RawJsonMessage; + + #[test] + fn parse_one_off_query() { + let message = r#"{ "one_off_query": { "message_id": "ywS3WFquDECZQ0UdLZN1IA==", "query_string": "SELECT * FROM User WHERE name != 'bananas'" } }"#; + let parsed = serde_json::from_str::(message).unwrap(); + + if let RawJsonMessage::OneOffQuery { + query_string: query, + message_id, + } = parsed + { + assert_eq!(query, "SELECT * FROM User WHERE name != 'bananas'"); + assert_eq!(message_id, "ywS3WFquDECZQ0UdLZN1IA=="); + } else { + panic!("wrong variant") + } + } +} diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index e35f9ac4ecc..bbbbc55ef5f 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -1,8 +1,14 @@ +use base64::Engine; use prost::Message as _; +use spacetimedb_client_api_messages::client_api::{OneOffQueryResponse, OneOffTable}; +use spacetimedb_lib::relation::MemTable; use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent}; use crate::identity::Identity; -use crate::json::client_api::{EventJson, FunctionCallJson, IdentityTokenJson, MessageJson, TransactionUpdateJson}; +use crate::json::client_api::{ + EventJson, FunctionCallJson, IdentityTokenJson, MessageJson, OneOffQueryResponseJson, OneOffTableJson, + TransactionUpdateJson, +}; use crate::protobuf::client_api::{event, message, Event, FunctionCall, IdentityToken, Message, TransactionUpdate}; use super::{DataMessage, Protocol}; @@ -177,3 +183,51 
@@ where } } } + +pub struct OneOffQueryResponseMessage { + pub message_id: Vec, + pub error: Option, + pub results: Vec, +} + +impl ServerMessage for OneOffQueryResponseMessage { + fn serialize_text(self) -> MessageJson { + MessageJson::OneOffQueryResponse(OneOffQueryResponseJson { + message_id_base64: base64::engine::general_purpose::STANDARD.encode(self.message_id), + error: self.error, + result: self + .results + .into_iter() + .map(|table| OneOffTableJson { + table_name: table.head.table_name, + rows: table.data.into_iter().map(|row| row.elements).collect(), + }) + .collect(), + }) + } + + fn serialize_binary(self) -> Message { + Message { + r#type: Some(message::Type::OneOffQueryResponse(OneOffQueryResponse { + message_id: self.message_id, + error: self.error.unwrap_or_default(), + tables: self + .results + .into_iter() + .map(|table| OneOffTable { + table_name: table.head.table_name, + row: table + .data + .into_iter() + .map(|row| { + let mut row_bytes = Vec::new(); + row.encode(&mut row_bytes); + row_bytes + }) + .collect(), + }) + .collect(), + })), + } + } +} diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index d7f5b58df6d..527677ac1b8 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -267,6 +267,24 @@ impl RelationalDB { self.finish_tx(tx, res) } + /// Run a fallible function in a transaction. + /// + /// This is similar to `with_auto_commit`, but regardless of the return value of + /// the fallible function, the transaction will ALWAYS be rolled back. This can be used to + /// emulate a read-only transaction. + /// + /// TODO(jgilles): when we support actual read-only transactions, use those here instead. 
+ pub fn with_read_only(&self, f: F) -> Result + where + F: FnOnce(&mut MutTxId) -> Result, + E: From, + { + let mut tx = self.begin_tx(); + let res = f(&mut tx); + self.rollback_tx(tx); + res + } + /// Perform the transactional logic for the `tx` according to the `res` pub fn finish_tx(&self, tx: MutTxId, res: Result) -> Result where diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 1ee83b7f31a..bb1f1ae4f96 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -11,6 +11,7 @@ use crate::protobuf::client_api::{table_row_operation, SubscriptionUpdate, Table use crate::subscription::module_subscription_actor::ModuleSubscriptionManager; use base64::{engine::general_purpose::STANDARD as BASE_64_STD, Engine as _}; use indexmap::IndexMap; +use spacetimedb_lib::relation::MemTable; use spacetimedb_lib::{ReducerDef, TableDef}; use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; use std::collections::HashMap; @@ -213,6 +214,11 @@ enum ModuleHostCommand { log_level: LogLevel, message: String, }, + OneOffQuery { + caller_identity: Identity, + query: String, + respond_to: oneshot::Sender, DBError>>, + }, } impl ModuleHostCommand { @@ -237,6 +243,11 @@ impl ModuleHostCommand { log_level, message, } => actor.inject_logs(respond_to, log_level, message), + ModuleHostCommand::OneOffQuery { + caller_identity, + query, + respond_to, + } => actor.one_off_query(caller_identity, respond_to, query), } } } @@ -272,6 +283,12 @@ pub trait ModuleHostActor: Send + 'static { fn init_database(&mut self, args: ArgsTuple, respond_to: oneshot::Sender>); fn update_database(&mut self, respond_to: oneshot::Sender>); fn inject_logs(&self, respond_to: oneshot::Sender<()>, log_level: LogLevel, message: String); + fn one_off_query( + &self, + caller_identity: Identity, + respond_to: oneshot::Sender, DBError>>, + query: String, + ); fn close(self); } @@ -484,6 +501,20 @@ impl ModuleHost { .await } + pub async fn 
one_off_query( + &self, + caller_identity: Identity, + query: String, + ) -> Result, anyhow::Error> { + Ok(self + .call(|respond_to| ModuleHostCommand::OneOffQuery { + caller_identity, + query, + respond_to, + }) + .await??) + } + pub fn downgrade(&self) -> WeakModuleHost { WeakModuleHost { info: self.info.clone(), diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 9beab528699..84b44fceb22 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -5,11 +5,14 @@ use std::time::{Duration, Instant}; use crate::db::datastore::traits::{ColumnDef, IndexDef, TableDef, TableSchema}; use crate::host::scheduler::Scheduler; +use crate::sql; use anyhow::Context; use bytes::Bytes; use parking_lot::{Condvar, Mutex}; use spacetimedb_lib::buffer::DecodeError; +use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::{bsatn, IndexType, ModuleDef}; +use spacetimedb_vm::expr::CrudExpr; use tokio::sync::oneshot; use crate::client::ClientConnectionSender; @@ -389,6 +392,31 @@ impl ModuleHostActor for WasmModuleHostActor { self.instances.seed().scheduler.close(); self.instances.join() } + + fn one_off_query( + &self, + caller_identity: Identity, + respond_to: oneshot::Sender, DBError>>, + query: String, + ) { + let db = &self.instances.shared.seed.worker_database_instance.relational_db; + let auth = AuthCtx::new(self.instances.shared.seed.info.identity, caller_identity); + // TODO(jgilles): make this a read-only TX when those get added + + let result = db.with_read_only(|tx| { + // NOTE(jgilles): this returns errors about mutating queries as SubscriptionErrors, which is perhaps + // mildly confusing, since the user did not subscribe to anything. Should we rename SubscriptionError to ReadOnlyQueryError? 
+ let compiled = crate::subscription::query::compile_read_only_query(db, tx, &query)?; + + sql::execute::execute_sql( + db, + tx, + compiled.queries.into_iter().map(CrudExpr::Query).collect(), + auth, + ) + }); + let _ = respond_to.send(result); + } } /// Somewhat ad-hoc wrapper around [`DatabaseLogger`] which allows to inject diff --git a/crates/core/src/json/client_api.rs b/crates/core/src/json/client_api.rs index cd1dfc237d7..9371ec0334b 100644 --- a/crates/core/src/json/client_api.rs +++ b/crates/core/src/json/client_api.rs @@ -34,6 +34,7 @@ pub enum MessageJson { Event(EventJson), TransactionUpdate(TransactionUpdateJson), IdentityToken(IdentityTokenJson), + OneOffQueryResponse(OneOffQueryResponseJson), } impl MessageJson { @@ -98,3 +99,16 @@ pub struct StmtResultJson { #[serde_as(as = "Vec>")] pub rows: Vec>, } + +#[derive(Debug, Clone, Serialize)] +pub struct OneOffQueryResponseJson { + pub message_id_base64: String, + pub error: Option, + pub result: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub struct OneOffTableJson { + pub table_name: String, + pub rows: Vec>, +} diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 44a10b10304..f2cbe2439e0 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::{ - query::compile_query, + query::compile_read_only_query, subscription::{QuerySet, Subscription}, }; use crate::db::datastore::locking_tx_datastore::MutTxId; @@ -149,7 +149,7 @@ impl ModuleSubscriptionActor { let queries: QuerySet = subscription .query_strings .into_iter() - .map(|query| compile_query(&self.relational_db, tx, &query)) + .map(|query| compile_read_only_query(&self.relational_db, tx, &query)) .collect::>()?; let sub = match self.subscriptions.iter_mut().find(|s| s.queries == queries) { diff --git a/crates/core/src/subscription/query.rs 
b/crates/core/src/subscription/query.rs index 2550ed30e93..019efd18254 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -81,7 +81,8 @@ pub(crate) fn run_query( } #[tracing::instrument(skip(relational_db, tx))] -pub fn compile_query(relational_db: &RelationalDB, tx: &MutTxId, input: &str) -> Result { +/// Compile a query, rejecting empty queries and queries that attempt to modify the data in any way. +pub fn compile_read_only_query(relational_db: &RelationalDB, tx: &MutTxId, input: &str) -> Result { let input = input.trim(); if input.is_empty() { return Err(SubscriptionError::Empty.into()); @@ -501,7 +502,7 @@ mod tests { run(&db, &mut tx, sql_create, AuthCtx::for_testing())?; let sql_query = "SELECT * FROM MobileEntityState JOIN EnemyState ON MobileEntityState.entity_id = EnemyState.entity_id WHERE location_x > 96000 AND MobileEntityState.location_x < 192000 AND MobileEntityState.location_z > 96000 AND MobileEntityState.location_z < 192000"; - let q = compile_query(&db, &tx, sql_query)?; + let q = compile_read_only_query(&db, &tx, sql_query)?; for q in q.queries { assert_eq!(