diff --git a/Cargo.toml b/Cargo.toml index 8796652b4..1876e61e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ uuid = { version = "1", default-features = false, optional = true } ouroboros = { version = "0.15", default-features = false } url = { version = "2.2", default-features = false } thiserror = { version = "1", default-features = false } +async-broadcast = { version = "0.5" } [dev-dependencies] smol = { version = "1.2" } @@ -60,6 +61,7 @@ pretty_assertions = { version = "0.7" } time = { version = "0.3", features = ["macros"] } uuid = { version = "1", features = ["v4"] } once_cell = "1.8" +async-channel = { version = "^1.7" } [features] debug-print = [] diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index d9095d758..0fc9dd5a6 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,6 +1,7 @@ use crate::{ - error::*, AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, - QueryResult, Statement, StatementBuilder, StreamTrait, TransactionError, TransactionTrait, + error::*, AccessMode, ConnectionTrait, DatabaseTransaction, EventStream, ExecResult, + IsolationLevel, QueryResult, Statement, StatementBuilder, StreamTrait, TransactionError, + TransactionTrait, }; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; use std::{future::Future, pin::Pin}; @@ -404,6 +405,15 @@ impl DatabaseConnection { } } } + + pub fn set_event_stream(&mut self, event_stream: E) -> E::Receiver + where + E: EventStream, + { + let (sender, receiver) = event_stream.subscribe(); + // TODO: Save the `sender` in `DatabaseConnection` + receiver + } } #[cfg(feature = "sea-orm-internal")] diff --git a/src/database/event.rs b/src/database/event.rs new file mode 100644 index 000000000..15fc1af8f --- /dev/null +++ b/src/database/event.rs @@ -0,0 +1,73 @@ +use crate::{DbErr, EntityTrait}; +use async_trait::async_trait; +use sea_query::{DynIden, Value}; +use std::{any::TypeId, collections::HashMap, fmt::Debug}; + +pub trait EventStream { + type Sender: EventSender; + type Receiver: EventReceiver; + + fn subscribe(self) -> (Self::Sender, Self::Receiver); +} + +#[async_trait] +pub trait EventSender { + async fn send(&self, event: Event) -> Result<(), DbErr>; +} + +#[async_trait] +pub trait EventReceiver { + async fn recv(&mut self) -> Result; +} + +#[derive(Debug, Clone)] +pub struct Event { + pub entity_type_id: TypeId, + pub action: EventAction, + pub values: HashMap, +} + +#[derive(Debug, Clone)] +pub enum EventAction { + Insert, + Update, + Delete, +} + +impl Event { + pub fn of_entity(&self) -> bool + where + E: EntityTrait, + { + self.entity_type_id == TypeId::of::() + } +} + +mod impl_event_stream_for_async_broadcast { + use super::*; + use async_broadcast::{Receiver, Sender}; + use futures::FutureExt; + + impl EventStream for (Sender, Receiver) { + type Sender = Sender; + type Receiver = Receiver; + + fn subscribe(self) -> (Self::Sender, Self::Receiver) { + self + } + } + + #[async_trait] + impl EventSender for Sender { + async fn send(&self, event: Event) -> Result<(), DbErr> { + self.broadcast(event).await.map(|_| ()).map_err(|e| todo!()) + } + } + + #[async_trait] + impl EventReceiver for Receiver { + async fn recv(&mut self) -> Result { + self.recv().await.map_err(|e| todo!()) + } + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 795789e6f..6f2706f8e 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -2,6 +2,7 @@ use std::time::Duration; mod connection; mod db_connection; +mod event; #[cfg(feature = "mock")] mod mock; mod statement; @@ -10,6 +11,7 @@ mod transaction; pub use connection::*; pub use db_connection::*; +pub use event::*; #[cfg(feature = "mock")] pub use mock::*; pub use statement::*; diff --git a/tests/basic.rs b/tests/basic.rs index a7ec72f26..f99235e27 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -10,7 +10,18 @@ pub use sea_orm::{entity::*, error::*, query::*, sea_query, tests_cfg::*, Databa async fn main() -> Result<(), DbErr> { let base_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_owned()); - let db: DbConn = Database::connect(&base_url).await?; + let mut db: DbConn = Database::connect(&base_url).await?; + + let mut tokio_receiver = db.set_event_stream(async_broadcast::broadcast(10)); + + while let Ok(event) = tokio_receiver.recv().await { + if event.of_entity::() { + if let Some(val) = event.values.get(cake::Column::Name.as_str()) { + todo!() + } + } + } + setup_schema(&db).await?; crud_cake(&db).await?;