diff --git a/mithril-aggregator/src/command_args.rs b/mithril-aggregator/src/command_args.rs index 520a9eff653..3cdf6d36629 100644 --- a/mithril-aggregator/src/command_args.rs +++ b/mithril-aggregator/src/command_args.rs @@ -34,7 +34,7 @@ use mithril_common::{ }; use crate::{ - database::provider::StakePoolRepository, + database::provider::StakePoolStore, event_store::{self, TransmitterService}, http_server::routes::router, tools::{EraTools, GenesisTools, GenesisToolsDependency}, @@ -90,9 +90,9 @@ fn setup_genesis_dependencies( )?), config.store_retention_limit, )); - let stake_store = Arc::new(StakePoolRepository::new(Arc::new(Mutex::new( - Connection::open(sqlite_db_path.clone().unwrap())?, - )))); + let stake_store = Arc::new(StakePoolStore::new(Arc::new(Mutex::new(Connection::open( + sqlite_db_path.clone().unwrap(), + )?)))); let single_signature_store = Arc::new(SingleSignatureStore::new( Box::new(SQLiteAdapter::new("single_signature", sqlite_db_path)?), config.store_retention_limit, @@ -359,9 +359,9 @@ impl ServeCommand { )?), config.store_retention_limit, )); - let stake_store = Arc::new(StakePoolRepository::new(Arc::new(Mutex::new( - Connection::open(sqlite_db_path.clone().unwrap())?, - )))); + let stake_store = Arc::new(StakePoolStore::new(Arc::new(Mutex::new(Connection::open( + sqlite_db_path.clone().unwrap(), + )?)))); let single_signature_store = Arc::new(SingleSignatureStore::new( Box::new(SQLiteAdapter::new( "single_signature", diff --git a/mithril-aggregator/src/database/provider/stake_pool.rs b/mithril-aggregator/src/database/provider/stake_pool.rs index 14c098dec01..3898518bf93 100644 --- a/mithril-aggregator/src/database/provider/stake_pool.rs +++ b/mithril-aggregator/src/database/provider/stake_pool.rs @@ -18,9 +18,11 @@ use mithril_common::{ use mithril_common::StdError; +/// Delete stake pools for Epoch older than this. +const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: Epoch = Epoch(2); + /// Stake pool as read from Chain. -/// TODO remove this compile directive ↓ -#[allow(dead_code)] +#[derive(Debug, PartialEq)] pub struct StakePool { /// Pool Id stake_pool_id: PartyId, @@ -119,7 +121,7 @@ impl<'client> Provider<'client> for StakePoolProvider<'client> { let aliases = SourceAlias::new(&[("{:stake_pool:}", "sp")]); let projection = Self::Entity::get_projection().expand(aliases); - format!("select {projection} from stake_pool as sp where {condition}") + format!("select {projection} from stake_pool as sp where {condition} order by epoch asc, stake desc") } } @@ -186,12 +188,54 @@ impl<'conn> Provider<'conn> for UpdateStakePoolProvider<'conn> { } } +/// Provider to remove old data from the stake_pool table +pub struct DeleteStakePoolProvider<'conn> { + connection: &'conn Connection, +} + +impl<'conn> Provider<'conn> for DeleteStakePoolProvider<'conn> { + type Entity = StakePool; + + fn get_connection(&'conn self) -> &'conn Connection { + self.connection + } + + fn get_definition(&self, condition: &str) -> String { + // it is important to alias the fields with the same name as the table + // since the table cannot be aliased in a RETURNING statement in SQLite. + let projection = Self::Entity::get_projection() + .expand(SourceAlias::new(&[("{:stake_pool:}", "stake_pool")])); + + format!("delete from stake_pool where {condition} returning {projection}") + } +} + +impl<'conn> DeleteStakePoolProvider<'conn> { + /// Create a new instance + pub fn new(connection: &'conn Connection) -> Self { + Self { connection } + } + + /// Create the SQL condition to prune data older than the given Epoch. + fn get_prune_condition(&self, epoch_threshold: Epoch) -> WhereCondition { + let epoch_value = Value::Integer(i64::try_from(epoch_threshold.0).unwrap()); + + WhereCondition::new("epoch < ?*", vec![epoch_value]) + } + + /// Prune the stake pools data older than the given epoch. + pub fn prune(&self, epoch_threshold: Epoch) -> Result, StdError> { + let filters = self.get_prune_condition(epoch_threshold); + + self.find(filters) + } +} /// Service to deal with stake pools (read & write). -pub struct StakePoolRepository { +pub struct StakePoolStore { connection: Arc>, } -impl StakePoolRepository { +impl StakePoolStore { /// Create a new StakePool service pub fn new(connection: Arc>) -> Self { Self { connection } @@ -199,7 +243,7 @@ impl StakePoolRepository { } #[async_trait] -impl StakeStorer for StakePoolRepository { +impl StakeStorer for StakePoolStore { async fn save_stakes( &self, epoch: Epoch, @@ -221,6 +265,12 @@ impl StakeStorer for StakePoolRepository { .map_err(|e| AdapterError::GeneralError(format!("{e}")))?; new_stakes.insert(pool_id.to_string(), stake_pool.stake); } + // Clean useless old stake distributions if needed. + if epoch > STAKE_POOL_PRUNE_EPOCH_THRESHOLD { + let _ = DeleteStakePoolProvider::new(connection) + .prune(epoch - STAKE_POOL_PRUNE_EPOCH_THRESHOLD) + .map_err(AdapterError::InitializationError)?; + } connection .execute("commit transaction") .map_err(|e| AdapterError::QueryError(e.into()))?; @@ -249,6 +299,8 @@ impl StakeStorer for StakePoolRepository { #[cfg(test)] mod tests { + use crate::database::migration::get_migrations; + use super::*; #[test] @@ -293,4 +345,127 @@ mod tests { params ); } + + #[test] + fn prune() { + let connection = Connection::open(":memory:").unwrap(); + let provider = DeleteStakePoolProvider::new(&connection); + let condition = provider.get_prune_condition(Epoch(5)); + let (condition, params) = condition.expand(); + + assert_eq!("epoch < ?1".to_string(), condition); + assert_eq!(vec![Value::Integer(5)], params); + } + + fn setup_db(connection: &Connection) -> Result<(), StdError> { + let migrations = get_migrations(); + let migration = + migrations + .iter() + .find(|&m| m.version == 1) + .ok_or_else(|| -> StdError { + "There should be a migration version 1".to_string().into() + })?; + let query = { + // leverage the expanded parameter from this provider which is unit + // tested on its own above. + let update_provider = UpdateStakePoolProvider::new(connection); + let (sql_values, _) = update_provider + .get_update_condition("pool_id", Epoch(1), 1000) + .expand(); + + connection.execute(&migration.alterations)?; + + format!("insert into stake_pool {sql_values}") + }; + let stake_distribution: &[(&str, i64, i64); 9] = &[ + ("pool1", 1, 1000), + ("pool2", 1, 1100), + ("pool3", 1, 1300), + ("pool1", 2, 1230), + ("pool2", 2, 1090), + ("pool3", 2, 1300), + ("pool1", 3, 1250), + ("pool2", 3, 1370), + ("pool3", 3, 1300), + ]; + for (pool_id, epoch, stake) in stake_distribution { + let mut statement = connection.prepare(&query)?; + + statement.bind(1, *pool_id).unwrap(); + statement.bind(2, *epoch).unwrap(); + statement.bind(3, *stake).unwrap(); + statement.next().unwrap(); + } + + Ok(()) + } + + #[test] + fn test_get_stake_pools() { + let connection = Connection::open(":memory:").unwrap(); + setup_db(&connection).unwrap(); + + let provider = StakePoolProvider::new(&connection); + let mut cursor = provider.get_by_epoch(&Epoch(1)).unwrap(); + + let stake_pool = cursor.next().expect("Should have a stake pool 'pool1'."); + assert_eq!("pool3".to_string(), stake_pool.stake_pool_id); + assert_eq!(Epoch(1), stake_pool.epoch); + assert_eq!(1300, stake_pool.stake); + assert_eq!(2, cursor.count()); + + let mut cursor = provider.get_by_epoch(&Epoch(3)).unwrap(); + + let stake_pool = cursor.next().expect("Should have a stake pool 'pool2'."); + assert_eq!("pool2".to_string(), stake_pool.stake_pool_id); + assert_eq!(Epoch(3), stake_pool.epoch); + assert_eq!(1370, stake_pool.stake); + assert_eq!(2, cursor.count()); + + let cursor = provider.get_by_epoch(&Epoch(5)).unwrap(); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_update_stakes() { + let connection = Connection::open(":memory:").unwrap(); + setup_db(&connection).unwrap(); + + let provider = UpdateStakePoolProvider::new(&connection); + let stake_pool = provider.persist("pool4", Epoch(3), 9999).unwrap(); + + assert_eq!("pool4".to_string(), stake_pool.stake_pool_id); + assert_eq!(Epoch(3), stake_pool.epoch); + assert_eq!(9999, stake_pool.stake); + + let provider = StakePoolProvider::new(&connection); + let mut cursor = provider.get_by_epoch(&Epoch(3)).unwrap(); + let stake_pool = cursor.next().expect("Should have a stake pool 'pool4'."); + + assert_eq!("pool4".to_string(), stake_pool.stake_pool_id); + assert_eq!(Epoch(3), stake_pool.epoch); + assert_eq!(9999, stake_pool.stake); + assert_eq!(3, cursor.count()); + } + + #[test] + fn test_prune() { + let connection = Connection::open(":memory:").unwrap(); + setup_db(&connection).unwrap(); + + let provider = DeleteStakePoolProvider::new(&connection); + let cursor = provider.prune(Epoch(2)).unwrap(); + + assert_eq!(3, cursor.count()); + + let provider = StakePoolProvider::new(&connection); + let cursor = provider.get_by_epoch(&Epoch(1)).unwrap(); + + assert_eq!(0, cursor.count()); + + let cursor = provider.get_by_epoch(&Epoch(2)).unwrap(); + + assert_eq!(3, cursor.count()); + } } diff --git a/mithril-common/src/entities/epoch.rs b/mithril-common/src/entities/epoch.rs index d0d426fab47..9b86d50b76b 100644 --- a/mithril-common/src/entities/epoch.rs +++ b/mithril-common/src/entities/epoch.rs @@ -99,7 +99,7 @@ impl Sub for Epoch { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 - rhs.0) + Self(self.0.saturating_sub(rhs.0)) } } @@ -107,7 +107,7 @@ impl Sub for Epoch { type Output = Self; fn sub(self, rhs: u64) -> Self::Output { - Self(self.0 - rhs) + Self(self.0.saturating_sub(rhs)) } } @@ -181,9 +181,16 @@ mod tests { assert_eq!(Epoch(8), epoch); } + #[test] + fn saturating_sub() { + assert_eq!(Epoch(0), Epoch(1) - Epoch(5)); + assert_eq!(Epoch(0), Epoch(1) - 5_u64); + } + #[test] fn test_previous() { assert_eq!(Epoch(2), Epoch(3).previous().unwrap()); + assert!(Epoch(0).previous().is_err()); } #[test] diff --git a/mithril-common/src/sqlite/source_alias.rs b/mithril-common/src/sqlite/source_alias.rs index c3df58feb6c..a0bde2a6a35 100644 --- a/mithril-common/src/sqlite/source_alias.rs +++ b/mithril-common/src/sqlite/source_alias.rs @@ -31,12 +31,11 @@ mod tests { #[test] fn simple_source_alias() { let source_alias = SourceAlias::new(&[("first", "one"), ("second", "two")]); - let target = source_alias - .get_iterator() - .map(|(name, alias)| format!("{name} => {alias}")) - .collect::>() - .join(", "); + let mut fields = "first.one, second.two".to_string(); - assert_eq!("first => one, second => two".to_string(), target); + for (alias, source) in source_alias.get_iterator() { + fields = fields.replace(alias, source); + } + assert_eq!("one.one, two.two".to_string(), fields); } }