From 266d6fa150e4bb9af33954abc46f40fc854c8704 Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Thu, 3 Oct 2024 22:14:47 +0200 Subject: [PATCH] Add gzip and none compression algos and let the sdk pick the compression --- Cargo.lock | 1 + crates/bench/benches/subscription.rs | 12 ++- crates/client-api-messages/Cargo.toml | 1 + crates/client-api-messages/src/websocket.rs | 79 ++++++++++++++----- crates/client-api/src/routes/subscribe.rs | 15 +++- crates/core/src/client/client_connection.rs | 6 +- crates/core/src/client/messages.rs | 32 +++++--- crates/core/src/host/module_host.rs | 6 +- .../core/src/subscription/execution_unit.rs | 6 +- .../subscription/module_subscription_actor.rs | 20 +++-- .../module_subscription_manager.rs | 7 +- crates/core/src/subscription/query.rs | 9 +-- crates/core/src/subscription/subscription.rs | 5 +- crates/sdk/examples/quickstart-chat/main.rs | 2 + crates/sdk/src/db_connection.rs | 16 +++- crates/sdk/src/websocket.rs | 19 ++++- 16 files changed, 173 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e4b98c6b2d..b8a3c14fc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4422,6 +4422,7 @@ dependencies = [ "chrono", "derive_more", "enum-as-inner", + "flate2", "hex", "itertools 0.12.1", "proptest", diff --git a/crates/bench/benches/subscription.rs b/crates/bench/benches/subscription.rs index 74b098e050..84b7c4b911 100644 --- a/crates/bench/benches/subscription.rs +++ b/crates/bench/benches/subscription.rs @@ -1,5 +1,4 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use spacetimedb::db::relational_db::RelationalDB; use spacetimedb::error::DBError; use spacetimedb::execution_context::ExecutionContext; use spacetimedb::host::module_host::DatabaseTableUpdate; @@ -7,6 +6,7 @@ use spacetimedb::identity::AuthCtx; use spacetimedb::messages::websocket::BsatnFormat; use spacetimedb::subscription::query::compile_read_only_query; use spacetimedb::subscription::subscription::ExecutionSet; +use spacetimedb::{db::relational_db::RelationalDB, messages::websocket::Compression}; use spacetimedb_bench::database::BenchDatabase as _; use spacetimedb_bench::spacetime_raw::SpacetimeRaw; use spacetimedb_primitives::{col_list, TableId}; @@ -104,7 +104,15 @@ fn eval(c: &mut Criterion) { let query = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap(); let query: ExecutionSet = query.into(); let ctx = &ExecutionContext::subscribe(raw.db.address()); - b.iter(|| drop(black_box(query.eval::(ctx, &raw.db, &tx, None)))) + b.iter(|| { + drop(black_box(query.eval::( + ctx, + &raw.db, + &tx, + None, + Compression::Brotli, + ))) + }) }); }; diff --git a/crates/client-api-messages/Cargo.toml b/crates/client-api-messages/Cargo.toml index 04495e838f..141a8514ab 100644 --- a/crates/client-api-messages/Cargo.toml +++ b/crates/client-api-messages/Cargo.toml @@ -15,6 +15,7 @@ bytestring.workspace = true brotli.workspace = true chrono = { workspace = true, features = ["serde"] } enum-as-inner.workspace = true +flate2.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true smallvec.workspace = true diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 33cbcd5c49..8bbc664614 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -33,7 +33,7 @@ use spacetimedb_sats::{ SpacetimeType, }; use std::{ - io::{self, Read as _}, + io::{self, Read as _, Write as _}, sync::Arc, }; @@ -74,7 +74,7 @@ pub trait WebsocketFormat: Sized { /// Convert a `QueryUpdate` into `Self::QueryUpdate`. /// This allows some formats to e.g., compress the update. - fn into_query_update(qu: QueryUpdate) -> Self::QueryUpdate; + fn into_query_update(qu: QueryUpdate, compression: Compression) -> Self::QueryUpdate; } /// Messages sent from the client to the server. @@ -165,12 +165,15 @@ pub struct OneOffQuery { pub query_string: Box, } -/// The tag recognized by ghe host and SDKs to mean no compression of a [`ServerMessage`]. +/// The tag recognized by the host and SDKs to mean no compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_NONE: u8 = 0; /// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_BROTLI: u8 = 1; +/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`]. +pub const SERVER_MSG_COMPRESSION_TAG_GZIP: u8 = 2; + /// Messages sent from the server to the client. #[derive(SpacetimeType, derive_more::From)] #[sats(crate = spacetimedb_lib)] @@ -357,13 +360,21 @@ impl TableUpdate { pub enum CompressableQueryUpdate { Uncompressed(QueryUpdate), Brotli(Bytes), + Gzip(Bytes), } impl CompressableQueryUpdate { pub fn maybe_decompress(self) -> QueryUpdate { match self { Self::Uncompressed(qu) => qu, - Self::Brotli(bytes) => brotli_decompress_qu(&bytes), + Self::Brotli(bytes) => { + let bytes = brotli_decompress(&bytes).unwrap(); + bsatn::from_slice(&bytes).unwrap() + } + Self::Gzip(bytes) => { + let bytes = gzip_decompress(&bytes).unwrap(); + bsatn::from_slice(&bytes).unwrap() + } } } } @@ -456,7 +467,7 @@ impl WebsocketFormat for JsonFormat { type QueryUpdate = QueryUpdate; - fn into_query_update(qu: QueryUpdate) -> Self::QueryUpdate { + fn into_query_update(qu: QueryUpdate, _: Compression) -> Self::QueryUpdate { qu } } @@ -499,27 +510,50 @@ impl WebsocketFormat for BsatnFormat { type QueryUpdate = CompressableQueryUpdate; - fn into_query_update(qu: QueryUpdate) -> Self::QueryUpdate { + fn into_query_update(qu: QueryUpdate, compression: Compression) -> Self::QueryUpdate { let qu_len_would_have_been = bsatn::to_len(&qu).unwrap(); - if should_compress(qu_len_would_have_been) { - let bytes = bsatn::to_vec(&qu).unwrap(); - let mut out = Vec::new(); - brotli_compress(&bytes, &mut out); - CompressableQueryUpdate::Brotli(out.into()) - } else { - CompressableQueryUpdate::Uncompressed(qu) + match decide_compression(qu_len_would_have_been, compression) { + Compression::None => CompressableQueryUpdate::Uncompressed(qu), + Compression::Brotli => { + let bytes = bsatn::to_vec(&qu).unwrap(); + let mut out = Vec::new(); + brotli_compress(&bytes, &mut out); + CompressableQueryUpdate::Brotli(out.into()) + } + Compression::Gzip => { + let bytes = bsatn::to_vec(&qu).unwrap(); + let mut out = Vec::new(); + gzip_compress(&bytes, &mut out); + CompressableQueryUpdate::Gzip(out.into()) + } } } } -pub fn should_compress(len: usize) -> bool { - /// The threshold at which we start to compress messages. +/// A specification of either a desired or decided compression algorithm. +#[derive(serde::Deserialize, Default, PartialEq, Eq, Clone, Copy, Hash, Debug)] +pub enum Compression { + /// No compression ever. + None, + /// Compress using brotli if a certain size threshold was met. + #[default] + Brotli, + /// Compress using gzip if a certain size threshold was met. + Gzip, +} + +pub fn decide_compression(len: usize, compression: Compression) -> Compression { + /// The threshold beyond which we start to compress messages. /// 1KiB was chosen without measurement. /// TODO(perf): measure! const COMPRESS_THRESHOLD: usize = 1024; - len <= COMPRESS_THRESHOLD + if len > COMPRESS_THRESHOLD { + compression + } else { + Compression::None + } } pub fn brotli_compress(bytes: &[u8], out: &mut Vec) { @@ -560,9 +594,16 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result, io::Error> { Ok(decompressed) } -pub fn brotli_decompress_qu(bytes: &[u8]) -> QueryUpdate { - let bytes = brotli_decompress(bytes).unwrap(); - bsatn::from_slice(&bytes).unwrap() +pub fn gzip_compress(bytes: &[u8], out: &mut Vec) { + let mut encoder = flate2::write::GzEncoder::new(out, flate2::Compression::fast()); + encoder.write_all(bytes).unwrap(); + encoder.finish().expect("Failed to gzip compress `bytes`"); +} + +pub fn gzip_decompress(bytes: &[u8]) -> Result, io::Error> { + let mut decompressed = Vec::new(); + let _ = flate2::read::GzDecoder::new(bytes).read(&mut decompressed)?; + Ok(decompressed) } type RowSize = u16; diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index cf45f04f92..2d723b3c4b 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -17,6 +17,7 @@ use spacetimedb::client::{ClientActorId, ClientConnection, DataMessage, MessageH use spacetimedb::host::NoSuchModule; use spacetimedb::util::also_poll; use spacetimedb::worker_metrics::WORKER_METRICS; +use spacetimedb_client_api_messages::websocket::Compression; use spacetimedb_lib::address::AddressForUrl; use spacetimedb_lib::Address; use std::time::Instant; @@ -42,6 +43,7 @@ pub struct SubscribeParams { #[derive(Deserialize)] pub struct SubscribeQueryParams { pub client_address: Option, + pub compression: Option, } // TODO: is this a reasonable way to generate client addresses? @@ -55,7 +57,10 @@ pub fn generate_random_address() -> Address { pub async fn handle_websocket( State(ctx): State, Path(SubscribeParams { name_or_address }): Path, - Query(SubscribeQueryParams { client_address }): Query, + Query(SubscribeQueryParams { + client_address, + compression, + }): Query, forwarded_for: Option>, Extension(auth): Extension, ws: WebSocketUpgrade, @@ -80,6 +85,7 @@ where ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]); let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?; + let compression = compression.unwrap_or_default(); // TODO: Should also maybe refactor the code and the protocol to allow a single websocket // to connect to multiple modules @@ -131,7 +137,8 @@ where } let actor = |client, sendrx| ws_client_actor(client, ws, sendrx); - let client = match ClientConnection::spawn(client_id, protocol, replica_id, module_rx, actor).await { + let client = match ClientConnection::spawn(client_id, protocol, compression, replica_id, module_rx, actor).await + { Ok(s) => s, Err(e) => { log::warn!("ModuleHost died while we were connecting: {e:#}"); @@ -259,7 +266,7 @@ async fn ws_client_actor_inner( let workload = msg.workload(); let num_rows = msg.num_rows(); - let msg = datamsg_to_wsmsg(serialize(msg, client.protocol)); + let msg = datamsg_to_wsmsg(serialize(msg, client.protocol, client.compression)); // These metrics should be updated together, // or not at all. @@ -347,7 +354,7 @@ async fn ws_client_actor_inner( if let Err(e) = res { if let MessageHandleError::Execution(err) = e { log::error!("{err:#}"); - let msg = serialize(err, client.protocol); + let msg = serialize(err, client.protocol, client.compression); if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await { log::warn!("Websocket send error: {error}") } diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 09ccaeac15..2d0136cf59 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -12,7 +12,7 @@ use crate::util::prometheus_handle::IntGaugeExt; use crate::worker_metrics::WORKER_METRICS; use derive_more::From; use futures::prelude::*; -use spacetimedb_client_api_messages::websocket::FormatSwitch; +use spacetimedb_client_api_messages::websocket::{Compression, FormatSwitch}; use spacetimedb_lib::identity::RequestId; use tokio::sync::{mpsc, oneshot, watch}; use tokio::task::AbortHandle; @@ -36,6 +36,7 @@ impl Protocol { pub struct ClientConnectionSender { pub id: ClientActorId, pub protocol: Protocol, + pub compression: Compression, sendtx: mpsc::Sender, abort_handle: AbortHandle, cancelled: AtomicBool, @@ -61,6 +62,7 @@ impl ClientConnectionSender { Self { id, protocol, + compression: Compression::Brotli, sendtx, abort_handle, cancelled: AtomicBool::new(false), @@ -143,6 +145,7 @@ impl ClientConnection { pub async fn spawn( id: ClientActorId, protocol: Protocol, + compression: Compression, replica_id: u64, mut module_rx: watch::Receiver, actor: F, @@ -178,6 +181,7 @@ impl ClientConnection { let sender = Arc::new(ClientConnectionSender { id, protocol, + compression, sendtx, abort_handle, cancelled: AtomicBool::new(false), diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 8497447dbe..36d5fa043f 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -5,8 +5,8 @@ use crate::host::ArgsTuple; use crate::messages::websocket as ws; use derive_more::From; use spacetimedb_client_api_messages::websocket::{ - BsatnFormat, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, - SERVER_MSG_COMPRESSION_TAG_NONE, + BsatnFormat, Compression, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, + SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, }; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; @@ -28,8 +28,13 @@ pub(super) type SwitchedServerMessage = FormatSwitch, protocol: Protocol) -> DataMessage { +/// If `protocol` is [`Protocol::Binary`], +/// the message will be conditionally compressed by this method according to `compression`. +pub fn serialize( + msg: impl ToProtocol, + protocol: Protocol, + compression: Compression, +) -> DataMessage { // TODO(centril, perf): here we are allocating buffers only to throw them away eventually. // Consider pooling these allocations so that we reuse them. match msg.to_protocol(protocol) { @@ -40,12 +45,19 @@ pub fn serialize(msg: impl ToProtocol, protocol bsatn::to_writer(&mut msg_bytes, &msg).unwrap(); // Conditionally compress the message. - let msg_bytes = if ws::should_compress(msg_bytes[1..].len()) { - let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI]; - ws::brotli_compress(&msg_bytes[1..], &mut out); - out - } else { - msg_bytes + let srv_msg = &msg_bytes[1..]; + let msg_bytes = match ws::decide_compression(srv_msg.len(), compression) { + Compression::None => msg_bytes, + Compression::Brotli => { + let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI]; + ws::brotli_compress(srv_msg, &mut out); + out + } + Compression::Gzip => { + let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP]; + ws::gzip_compress(srv_msg, &mut out); + out + } }; msg_bytes.into() } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 76d21056b6..8492fbdc5c 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -24,7 +24,7 @@ use indexmap::IndexSet; use itertools::Itertools; use smallvec::SmallVec; use spacetimedb_client_api_messages::timestamp::Timestamp; -use spacetimedb_client_api_messages::websocket::{QueryUpdate, WebsocketFormat}; +use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, WebsocketFormat}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_lib::identity::{AuthCtx, RequestId}; @@ -124,12 +124,12 @@ impl UpdatesRelValue<'_> { !(self.deletes.is_empty() && self.inserts.is_empty()) } - pub fn encode(&self) -> (F::QueryUpdate, u64) { + pub fn encode(&self, compression: Compression) -> (F::QueryUpdate, u64) { let (deletes, nr_del) = F::encode_list(self.deletes.iter()); let (inserts, nr_ins) = F::encode_list(self.inserts.iter()); let num_rows = nr_del + nr_ins; let qu = QueryUpdate { deletes, inserts }; - let cqu = F::into_query_update(qu); + let cqu = F::into_query_update(qu, compression); (cqu, num_rows) } } diff --git a/crates/core/src/subscription/execution_unit.rs b/crates/core/src/subscription/execution_unit.rs index 71634d4df8..c79b35d416 100644 --- a/crates/core/src/subscription/execution_unit.rs +++ b/crates/core/src/subscription/execution_unit.rs @@ -9,7 +9,7 @@ use crate::host::module_host::{DatabaseTableUpdate, DatabaseTableUpdateRelValue, use crate::messages::websocket::TableUpdate; use crate::util::slow::SlowQueryLogger; use crate::vm::{build_query, TxMode}; -use spacetimedb_client_api_messages::websocket::{QueryUpdate, RowListLen as _, WebsocketFormat}; +use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, RowListLen as _, WebsocketFormat}; use spacetimedb_lib::db::error::AuthError; use spacetimedb_lib::relation::DbTable; use spacetimedb_lib::{Identity, ProductValue}; @@ -209,6 +209,7 @@ impl ExecutionUnit { tx: &Tx, sql: &str, slow_query_threshold: Option, + compression: Compression, ) -> Option> { let _slow_query = SlowQueryLogger::new(sql, slow_query_threshold, ctx.workload()).log_guard(); @@ -220,7 +221,8 @@ impl ExecutionUnit { (!inserts.is_empty()).then(|| { let deletes = F::List::default(); - let update = F::into_query_update(QueryUpdate { deletes, inserts }); + let qu = QueryUpdate { deletes, inserts }; + let update = F::into_query_update(qu, compression); TableUpdate::new(self.return_table(), self.return_name(), (update, num_rows)) }) } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index bcaf5a939a..5bf97bd8e1 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -120,12 +120,20 @@ impl ModuleSubscriptions { let slow_query_threshold = StVarTable::sub_limit(&ctx, &self.relational_db, &tx)?.map(Duration::from_millis); let database_update = match sender.protocol { - Protocol::Text => { - FormatSwitch::Json(execution_set.eval(&ctx, &self.relational_db, &tx, slow_query_threshold)) - } - Protocol::Binary => { - FormatSwitch::Bsatn(execution_set.eval(&ctx, &self.relational_db, &tx, slow_query_threshold)) - } + Protocol::Text => FormatSwitch::Json(execution_set.eval( + &ctx, + &self.relational_db, + &tx, + slow_query_threshold, + sender.compression, + )), + Protocol::Binary => FormatSwitch::Bsatn(execution_set.eval( + &ctx, + &self.relational_db, + &tx, + slow_query_threshold, + sender.compression, + )), }; // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently. diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 43a0f3e48b..2ed27f812a 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -168,15 +168,16 @@ impl SubscriptionManager { let mut ops_bin: Option<(CompressableQueryUpdate, _)> = None; let mut ops_json: Option<(QueryUpdate, _)> = None; self.subscribers.get(hash).into_iter().flatten().map(move |id| { - let ops = match self.clients[id].protocol { + let client = &*self.clients[id]; + let ops = match client.protocol { Protocol::Binary => Bsatn( ops_bin - .get_or_insert_with(|| delta.updates.encode::()) + .get_or_insert_with(|| delta.updates.encode::(client.compression)) .clone(), ), Protocol::Text => Json( ops_json - .get_or_insert_with(|| delta.updates.encode::()) + .get_or_insert_with(|| delta.updates.encode::(client.compression)) .clone(), ), }; diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index a369d64a81..0f20b5b67c 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -101,7 +101,7 @@ mod tests { use crate::vm::tests::create_table_with_rows; use crate::vm::DbProgram; use itertools::Itertools; - use spacetimedb_client_api_messages::websocket::{brotli_decompress_qu, BsatnFormat, CompressableQueryUpdate}; + use spacetimedb_client_api_messages::websocket::{BsatnFormat, Compression}; use spacetimedb_lib::bsatn; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::error::ResultTest; @@ -302,7 +302,7 @@ mod tests { total_tables: usize, rows: &[ProductValue], ) -> ResultTest<()> { - let result = s.eval::(ctx, db, tx, None).tables; + let result = s.eval::(ctx, db, tx, None, Compression::Brotli).tables; assert_eq!( result.len(), total_tables, @@ -312,10 +312,7 @@ mod tests { let result = result .into_iter() .flat_map(|x| x.updates) - .map(|x| match x { - CompressableQueryUpdate::Uncompressed(qu) => qu, - CompressableQueryUpdate::Brotli(bytes) => brotli_decompress_qu(&bytes), - }) + .map(|x| x.maybe_decompress()) .flat_map(|x| { (&x.deletes) .into_iter() diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 68307ff805..3fa6bfecd4 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -32,7 +32,7 @@ use crate::vm::{build_query, TxMode}; use anyhow::Context; use itertools::Either; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -use spacetimedb_client_api_messages::websocket::WebsocketFormat; +use spacetimedb_client_api_messages::websocket::{Compression, WebsocketFormat}; use spacetimedb_data_structures::map::HashSet; use spacetimedb_lib::db::auth::{StAccess, StTableType}; use spacetimedb_lib::db::error::AuthError; @@ -521,13 +521,14 @@ impl ExecutionSet { db: &RelationalDB, tx: &Tx, slow_query_threshold: Option, + compression: Compression, ) -> ws::DatabaseUpdate { // evaluate each of the execution units in this ExecutionSet in parallel let tables = self .exec_units // if you need eval to run single-threaded for debugging, change this to .iter() .par_iter() - .filter_map(|unit| unit.eval(ctx, db, tx, &unit.sql, slow_query_threshold)) + .filter_map(|unit| unit.eval(ctx, db, tx, &unit.sql, slow_query_threshold, compression)) .collect(); ws::DatabaseUpdate { tables } } diff --git a/crates/sdk/examples/quickstart-chat/main.rs b/crates/sdk/examples/quickstart-chat/main.rs index aa88a30d99..5d5732d1bb 100644 --- a/crates/sdk/examples/quickstart-chat/main.rs +++ b/crates/sdk/examples/quickstart-chat/main.rs @@ -2,6 +2,7 @@ mod module_bindings; use module_bindings::*; +use spacetimedb_client_api_messages::websocket::Compression; use spacetimedb_sdk::{credentials, DbContext, Event, Identity, ReducerEvent, Status, Table, TableWithPrimaryKey}; // # Our main function @@ -171,6 +172,7 @@ fn connect_to_db() -> DbConnection { .with_credentials(creds_store().load().expect("Error loading credentials")) .with_module_name(DB_NAME) .with_uri(HOST) + .with_compression(Compression::Gzip) .build() .expect("Failed to connect") } diff --git a/crates/sdk/src/db_connection.rs b/crates/sdk/src/db_connection.rs index be3d10d3cd..12883f8ae2 100644 --- a/crates/sdk/src/db_connection.rs +++ b/crates/sdk/src/db_connection.rs @@ -31,7 +31,7 @@ use bytes::Bytes; use futures::StreamExt; use futures_channel::mpsc; use http::Uri; -use spacetimedb_client_api_messages::websocket::BsatnFormat; +use spacetimedb_client_api_messages::websocket::{BsatnFormat, Compression}; use spacetimedb_lib::{bsatn, de::Deserialize, ser::Serialize, Address, Identity}; use std::{ sync::{Arc, Mutex as StdMutex, OnceLock}, @@ -727,6 +727,8 @@ pub struct DbConnectionBuilder { on_connect: Option>, on_connect_error: Option, on_disconnect: Option>, + + compression: Compression, } /// This process's global client address, which will be attacked to all connections it makes. @@ -769,6 +771,7 @@ impl DbConnectionBuilder { on_connect: None, on_connect_error: None, on_disconnect: None, + compression: <_>::default(), } } @@ -815,6 +818,7 @@ but you must call one of them, or else the connection will never progress. self.module_name.as_ref().unwrap(), self.credentials.as_ref(), get_client_address(), + self.compression, )) })?; @@ -877,6 +881,16 @@ but you must call one of them, or else the connection will never progress. self } + /// Sets the compression used when a certain threshold in the message size has been reached. + /// + /// The current threshold used by the host is 1KiB for the entire server message + /// and for individual query updates. + /// Note however that this threshold is not guaranteed and may change without notice. + pub fn with_compression(mut self, compression: Compression) -> Self { + self.compression = compression; + self + } + /// Register a callback to run when the connection is successfully initiated. /// /// The callback will receive three arguments: diff --git a/crates/sdk/src/websocket.rs b/crates/sdk/src/websocket.rs index f9b0631965..8a4204f2ea 100644 --- a/crates/sdk/src/websocket.rs +++ b/crates/sdk/src/websocket.rs @@ -9,7 +9,8 @@ use futures::{SinkExt, StreamExt as _, TryStreamExt}; use futures_channel::mpsc; use http::uri::{Scheme, Uri}; use spacetimedb_client_api_messages::websocket::{ - brotli_decompress, BsatnFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_NONE, + brotli_decompress, gzip_decompress, BsatnFormat, Compression, SERVER_MSG_COMPRESSION_TAG_BROTLI, + SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, }; use spacetimedb_lib::{bsatn, Address, Identity}; use tokio::task::JoinHandle; @@ -37,7 +38,7 @@ fn parse_scheme(scheme: Option) -> Result { }) } -fn make_uri(host: Host, db_name: &str, client_address: Address) -> Result +fn make_uri(host: Host, db_name: &str, client_address: Address, compression: Compression) -> Result where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, @@ -62,6 +63,11 @@ where path.push_str(db_name); path.push_str("?client_address="); path.push_str(&client_address.to_hex()); + match compression { + Compression::None => path.push_str("&compression=None"), + Compression::Gzip => path.push_str("&compression=Gzip"), + Compression::Brotli => path.push_str("&compression=Brotli"), + }; parts.path_and_query = Some(path.parse()?); Ok(Uri::from_parts(parts)?) } @@ -80,12 +86,13 @@ fn make_request( db_name: &str, credentials: Option<&(Identity, String)>, client_address: Address, + compression: Compression, ) -> Result> where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, { - let uri = make_uri(host, db_name, client_address)?; + let uri = make_uri(host, db_name, client_address, compression)?; let mut req = IntoClientRequest::into_client_request(uri)?; request_insert_protocol_header(&mut req); request_insert_auth_header(&mut req, credentials); @@ -134,12 +141,13 @@ impl WsConnection { db_name: &str, credentials: Option<&(Identity, String)>, client_address: Address, + compression: Compression, ) -> Result where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, { - let req = make_request(host, db_name, credentials, client_address)?; + let req = make_request(host, db_name, credentials, client_address, compression)?; let (sock, _): (WebSocketStream>, _) = connect_async_with_config( req, // TODO(kim): In order to be able to replicate module WASM blobs, @@ -166,6 +174,9 @@ impl WsConnection { SERVER_MSG_COMPRESSION_TAG_BROTLI => { bsatn::from_slice(&brotli_decompress(bytes).context("Failed to Brotli decompress message")?)? } + SERVER_MSG_COMPRESSION_TAG_GZIP => { + bsatn::from_slice(&gzip_decompress(bytes).context("Failed to gzip decompress message")?)? + } c => bail!("Unknown compression format `{c}`"), }) }