Skip to content

Commit

Permalink
Add gzip and none compression algos and let the sdk pick the compression
Browse files Browse the repository at this point in the history
  • Loading branch information
Centril committed Oct 9, 2024
1 parent 4db4c9a commit 266d6fa
Show file tree
Hide file tree
Showing 16 changed files with 173 additions and 63 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions crates/bench/benches/subscription.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use spacetimedb::db::relational_db::RelationalDB;
use spacetimedb::error::DBError;
use spacetimedb::execution_context::ExecutionContext;
use spacetimedb::host::module_host::DatabaseTableUpdate;
use spacetimedb::identity::AuthCtx;
use spacetimedb::messages::websocket::BsatnFormat;
use spacetimedb::subscription::query::compile_read_only_query;
use spacetimedb::subscription::subscription::ExecutionSet;
use spacetimedb::{db::relational_db::RelationalDB, messages::websocket::Compression};
use spacetimedb_bench::database::BenchDatabase as _;
use spacetimedb_bench::spacetime_raw::SpacetimeRaw;
use spacetimedb_primitives::{col_list, TableId};
Expand Down Expand Up @@ -104,7 +104,15 @@ fn eval(c: &mut Criterion) {
let query = compile_read_only_query(&raw.db, &AuthCtx::for_testing(), &tx, sql).unwrap();
let query: ExecutionSet = query.into();
let ctx = &ExecutionContext::subscribe(raw.db.address());
b.iter(|| drop(black_box(query.eval::<BsatnFormat>(ctx, &raw.db, &tx, None))))
b.iter(|| {
drop(black_box(query.eval::<BsatnFormat>(
ctx,
&raw.db,
&tx,
None,
Compression::Brotli,
)))
})
});
};

Expand Down
1 change: 1 addition & 0 deletions crates/client-api-messages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ bytestring.workspace = true
brotli.workspace = true
chrono = { workspace = true, features = ["serde"] }
enum-as-inner.workspace = true
flate2.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
smallvec.workspace = true
Expand Down
79 changes: 60 additions & 19 deletions crates/client-api-messages/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use spacetimedb_sats::{
SpacetimeType,
};
use std::{
io::{self, Read as _},
io::{self, Read as _, Write as _},
sync::Arc,
};

Expand Down Expand Up @@ -74,7 +74,7 @@ pub trait WebsocketFormat: Sized {

/// Convert a `QueryUpdate` into `Self::QueryUpdate`.
/// This allows some formats to e.g., compress the update.
fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate;
fn into_query_update(qu: QueryUpdate<Self>, compression: Compression) -> Self::QueryUpdate;
}

/// Messages sent from the client to the server.
Expand Down Expand Up @@ -165,12 +165,15 @@ pub struct OneOffQuery {
pub query_string: Box<str>,
}

/// The tag recognized by ghe host and SDKs to mean no compression of a [`ServerMessage`].
/// The tag recognized by the host and SDKs to mean no compression of a [`ServerMessage`].
pub const SERVER_MSG_COMPRESSION_TAG_NONE: u8 = 0;

/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`].
pub const SERVER_MSG_COMPRESSION_TAG_BROTLI: u8 = 1;

/// The tag recognized by the host and SDKs to mean gzip compression of a [`ServerMessage`].
pub const SERVER_MSG_COMPRESSION_TAG_GZIP: u8 = 2;

/// Messages sent from the server to the client.
#[derive(SpacetimeType, derive_more::From)]
#[sats(crate = spacetimedb_lib)]
Expand Down Expand Up @@ -357,13 +360,21 @@ impl<F: WebsocketFormat> TableUpdate<F> {
pub enum CompressableQueryUpdate<F: WebsocketFormat> {
/// The update as-is, with no compression applied.
Uncompressed(QueryUpdate<F>),
/// Brotli-compressed bytes of a BSATN-serialized [`QueryUpdate`].
Brotli(Bytes),
/// Gzip-compressed bytes of a BSATN-serialized [`QueryUpdate`].
Gzip(Bytes),
}

impl CompressableQueryUpdate<BsatnFormat> {
/// Returns the plain [`QueryUpdate`], decompressing and BSATN-decoding it first
/// when `self` is one of the compressed variants.
///
/// # Panics
///
/// Panics (via `unwrap`) if decompression fails
/// or if the decompressed bytes cannot be decoded with `bsatn::from_slice`.
pub fn maybe_decompress(self) -> QueryUpdate<BsatnFormat> {
match self {
// Nothing to undo; hand the update back unchanged.
Self::Uncompressed(qu) => qu,
// NOTE(review): two `Self::Brotli` arms appear here; the first one matches and
// the second can never be reached. Presumably the first is the pre-change line
// as rendered by the diff view -- confirm against the actual file.
Self::Brotli(bytes) => brotli_decompress_qu(&bytes),
Self::Brotli(bytes) => {
let bytes = brotli_decompress(&bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
}
// Same two steps as brotli: decompress, then decode the BSATN payload.
Self::Gzip(bytes) => {
let bytes = gzip_decompress(&bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
}
}
}
}
Expand Down Expand Up @@ -456,7 +467,7 @@ impl WebsocketFormat for JsonFormat {

type QueryUpdate = QueryUpdate<Self>;

fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate {
fn into_query_update(qu: QueryUpdate<Self>, _: Compression) -> Self::QueryUpdate {
qu
}
}
Expand Down Expand Up @@ -499,27 +510,50 @@ impl WebsocketFormat for BsatnFormat {

type QueryUpdate = CompressableQueryUpdate<Self>;

fn into_query_update(qu: QueryUpdate<Self>) -> Self::QueryUpdate {
fn into_query_update(qu: QueryUpdate<Self>, compression: Compression) -> Self::QueryUpdate {
let qu_len_would_have_been = bsatn::to_len(&qu).unwrap();

if should_compress(qu_len_would_have_been) {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
brotli_compress(&bytes, &mut out);
CompressableQueryUpdate::Brotli(out.into())
} else {
CompressableQueryUpdate::Uncompressed(qu)
match decide_compression(qu_len_would_have_been, compression) {
Compression::None => CompressableQueryUpdate::Uncompressed(qu),
Compression::Brotli => {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
brotli_compress(&bytes, &mut out);
CompressableQueryUpdate::Brotli(out.into())
}
Compression::Gzip => {
let bytes = bsatn::to_vec(&qu).unwrap();
let mut out = Vec::new();
gzip_compress(&bytes, &mut out);
CompressableQueryUpdate::Gzip(out.into())
}
}
}
}

pub fn should_compress(len: usize) -> bool {
/// The threshold at which we start to compress messages.
/// A specification of either a desired or decided compression algorithm.
///
/// "Desired" is what a caller asks for; [`decide_compression`] turns that into the
/// "decided" algorithm based on the message length.
/// The `Default` for this type is [`Compression::Brotli`] (see `#[default]` below).
#[derive(serde::Deserialize, Default, PartialEq, Eq, Clone, Copy, Hash, Debug)]
pub enum Compression {
/// No compression ever.
None,
/// Compress using brotli if a certain size threshold was met.
///
/// This is the default variant.
#[default]
Brotli,
/// Compress using gzip if a certain size threshold was met.
Gzip,
}

/// Decides which compression to actually apply to a message of `len` bytes,
/// given the requested `compression`.
///
/// Per the `if`/`else` at the end of the body: returns `compression` unchanged when
/// `len` is greater than `COMPRESS_THRESHOLD`, and [`Compression::None`] otherwise.
pub fn decide_compression(len: usize, compression: Compression) -> Compression {
/// The threshold beyond which we start to compress messages.
/// 1KiB was chosen without measurement.
/// TODO(perf): measure!
const COMPRESS_THRESHOLD: usize = 1024;

// NOTE(review): the bare comparison on the next line directly precedes the `if`
// with no separator. It looks like a leftover of the removed `should_compress`
// body as rendered by the diff view -- confirm against the actual file.
len <= COMPRESS_THRESHOLD
if len > COMPRESS_THRESHOLD {
compression
} else {
Compression::None
}
}

pub fn brotli_compress(bytes: &[u8], out: &mut Vec<u8>) {
Expand Down Expand Up @@ -560,9 +594,16 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
Ok(decompressed)
}

pub fn brotli_decompress_qu(bytes: &[u8]) -> QueryUpdate<BsatnFormat> {
let bytes = brotli_decompress(bytes).unwrap();
bsatn::from_slice(&bytes).unwrap()
/// Gzip-compresses `bytes` and writes the compressed stream into `out`.
///
/// Existing contents of `out` are kept; the compressed data is written after them.
/// The encoder uses flate2's `fast` compression level.
///
/// # Panics
///
/// Panics if writing to or finishing the encoder fails.
pub fn gzip_compress(bytes: &[u8], out: &mut Vec<u8>) {
    let level = flate2::Compression::fast();
    let mut gz = flate2::write::GzEncoder::new(out, level);
    gz.write_all(bytes).unwrap();
    // `finish` writes the gzip trailer; the returned writer is not needed here.
    gz.finish().expect("Failed to gzip compress `bytes`");
}

/// Decompresses the gzip stream in `bytes` and returns the decompressed data.
///
/// # Errors
///
/// Returns the `io::Error` reported by the decoder while reading,
/// e.g. when `bytes` is not a valid gzip stream.
pub fn gzip_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
    let mut decompressed = Vec::new();
    // `read_to_end` grows the buffer until the decoder is exhausted.
    // A plain `read` into an empty `Vec` sees a zero-length destination slice
    // and therefore never produces any output.
    flate2::read::GzDecoder::new(bytes).read_to_end(&mut decompressed)?;
    Ok(decompressed)
}

type RowSize = u16;
Expand Down
15 changes: 11 additions & 4 deletions crates/client-api/src/routes/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use spacetimedb::client::{ClientActorId, ClientConnection, DataMessage, MessageH
use spacetimedb::host::NoSuchModule;
use spacetimedb::util::also_poll;
use spacetimedb::worker_metrics::WORKER_METRICS;
use spacetimedb_client_api_messages::websocket::Compression;
use spacetimedb_lib::address::AddressForUrl;
use spacetimedb_lib::Address;
use std::time::Instant;
Expand All @@ -42,6 +43,7 @@ pub struct SubscribeParams {
/// Query-string parameters accepted by the websocket subscribe route.
#[derive(Deserialize)]
pub struct SubscribeQueryParams {
/// Optional client address supplied in the query string.
pub client_address: Option<AddressForUrl>,
/// Optional compression requested for this connection.
/// When absent, `handle_websocket` falls back to `Compression::default()`.
pub compression: Option<Compression>,
}

// TODO: is this a reasonable way to generate client addresses?
Expand All @@ -55,7 +57,10 @@ pub fn generate_random_address() -> Address {
pub async fn handle_websocket<S>(
State(ctx): State<S>,
Path(SubscribeParams { name_or_address }): Path<SubscribeParams>,
Query(SubscribeQueryParams { client_address }): Query<SubscribeQueryParams>,
Query(SubscribeQueryParams {
client_address,
compression,
}): Query<SubscribeQueryParams>,
forwarded_for: Option<TypedHeader<XForwardedFor>>,
Extension(auth): Extension<SpacetimeAuth>,
ws: WebSocketUpgrade,
Expand All @@ -80,6 +85,7 @@ where
ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);

let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?;
let compression = compression.unwrap_or_default();

// TODO: Should also maybe refactor the code and the protocol to allow a single websocket
// to connect to multiple modules
Expand Down Expand Up @@ -131,7 +137,8 @@ where
}

let actor = |client, sendrx| ws_client_actor(client, ws, sendrx);
let client = match ClientConnection::spawn(client_id, protocol, replica_id, module_rx, actor).await {
let client = match ClientConnection::spawn(client_id, protocol, compression, replica_id, module_rx, actor).await
{
Ok(s) => s,
Err(e) => {
log::warn!("ModuleHost died while we were connecting: {e:#}");
Expand Down Expand Up @@ -259,7 +266,7 @@ async fn ws_client_actor_inner(
let workload = msg.workload();
let num_rows = msg.num_rows();

let msg = datamsg_to_wsmsg(serialize(msg, client.protocol));
let msg = datamsg_to_wsmsg(serialize(msg, client.protocol, client.compression));

// These metrics should be updated together,
// or not at all.
Expand Down Expand Up @@ -347,7 +354,7 @@ async fn ws_client_actor_inner(
if let Err(e) = res {
if let MessageHandleError::Execution(err) = e {
log::error!("{err:#}");
let msg = serialize(err, client.protocol);
let msg = serialize(err, client.protocol, client.compression);
if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await {
log::warn!("Websocket send error: {error}")
}
Expand Down
6 changes: 5 additions & 1 deletion crates/core/src/client/client_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::util::prometheus_handle::IntGaugeExt;
use crate::worker_metrics::WORKER_METRICS;
use derive_more::From;
use futures::prelude::*;
use spacetimedb_client_api_messages::websocket::FormatSwitch;
use spacetimedb_client_api_messages::websocket::{Compression, FormatSwitch};
use spacetimedb_lib::identity::RequestId;
use tokio::sync::{mpsc, oneshot, watch};
use tokio::task::AbortHandle;
Expand All @@ -36,6 +36,7 @@ impl Protocol {
pub struct ClientConnectionSender {
pub id: ClientActorId,
pub protocol: Protocol,
pub compression: Compression,
sendtx: mpsc::Sender<SerializableMessage>,
abort_handle: AbortHandle,
cancelled: AtomicBool,
Expand All @@ -61,6 +62,7 @@ impl ClientConnectionSender {
Self {
id,
protocol,
compression: Compression::Brotli,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
Expand Down Expand Up @@ -143,6 +145,7 @@ impl ClientConnection {
pub async fn spawn<F, Fut>(
id: ClientActorId,
protocol: Protocol,
compression: Compression,
replica_id: u64,
mut module_rx: watch::Receiver<ModuleHost>,
actor: F,
Expand Down Expand Up @@ -178,6 +181,7 @@ impl ClientConnection {
let sender = Arc::new(ClientConnectionSender {
id,
protocol,
compression,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
Expand Down
32 changes: 22 additions & 10 deletions crates/core/src/client/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::host::ArgsTuple;
use crate::messages::websocket as ws;
use derive_more::From;
use spacetimedb_client_api_messages::websocket::{
BsatnFormat, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI,
SERVER_MSG_COMPRESSION_TAG_NONE,
BsatnFormat, Compression, FormatSwitch, JsonFormat, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI,
SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
};
use spacetimedb_lib::identity::RequestId;
use spacetimedb_lib::ser::serde::SerializeWrapper;
Expand All @@ -28,8 +28,13 @@ pub(super) type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnForm

/// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`].
///
/// If `protocol` is [`Protocol::Binary`], the message will be compressed by this method.
pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, protocol: Protocol) -> DataMessage {
/// If `protocol` is [`Protocol::Binary`],
/// the message will be conditionally compressed by this method according to `compression`.
pub fn serialize(
msg: impl ToProtocol<Encoded = SwitchedServerMessage>,
protocol: Protocol,
compression: Compression,
) -> DataMessage {
// TODO(centril, perf): here we are allocating buffers only to throw them away eventually.
// Consider pooling these allocations so that we reuse them.
match msg.to_protocol(protocol) {
Expand All @@ -40,12 +45,19 @@ pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, protocol
bsatn::to_writer(&mut msg_bytes, &msg).unwrap();

// Conditionally compress the message.
let msg_bytes = if ws::should_compress(msg_bytes[1..].len()) {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
ws::brotli_compress(&msg_bytes[1..], &mut out);
out
} else {
msg_bytes
let srv_msg = &msg_bytes[1..];
let msg_bytes = match ws::decide_compression(srv_msg.len(), compression) {
Compression::None => msg_bytes,
Compression::Brotli => {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
ws::brotli_compress(srv_msg, &mut out);
out
}
Compression::Gzip => {
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP];
ws::gzip_compress(srv_msg, &mut out);
out
}
};
msg_bytes.into()
}
Expand Down
6 changes: 3 additions & 3 deletions crates/core/src/host/module_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use indexmap::IndexSet;
use itertools::Itertools;
use smallvec::SmallVec;
use spacetimedb_client_api_messages::timestamp::Timestamp;
use spacetimedb_client_api_messages::websocket::{QueryUpdate, WebsocketFormat};
use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, WebsocketFormat};
use spacetimedb_data_structures::error_stream::ErrorStream;
use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
use spacetimedb_lib::identity::{AuthCtx, RequestId};
Expand Down Expand Up @@ -124,12 +124,12 @@ impl UpdatesRelValue<'_> {
!(self.deletes.is_empty() && self.inserts.is_empty())
}

pub fn encode<F: WebsocketFormat>(&self) -> (F::QueryUpdate, u64) {
pub fn encode<F: WebsocketFormat>(&self, compression: Compression) -> (F::QueryUpdate, u64) {
let (deletes, nr_del) = F::encode_list(self.deletes.iter());
let (inserts, nr_ins) = F::encode_list(self.inserts.iter());
let num_rows = nr_del + nr_ins;
let qu = QueryUpdate { deletes, inserts };
let cqu = F::into_query_update(qu);
let cqu = F::into_query_update(qu, compression);
(cqu, num_rows)
}
}
Expand Down
Loading

0 comments on commit 266d6fa

Please sign in to comment.