Skip to content

Commit

Permalink
Merge pull request #808 from input-output-hk/greg/799/pruning
Browse files Browse the repository at this point in the history
Add stake SQL store pruning
  • Loading branch information
ghubertpalo authored Mar 16, 2023
2 parents 7e2d358 + 2ac1fc6 commit a35fc91
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 21 deletions.
14 changes: 7 additions & 7 deletions mithril-aggregator/src/command_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use mithril_common::{
};

use crate::{
database::provider::StakePoolRepository,
database::provider::StakePoolStore,
event_store::{self, TransmitterService},
http_server::routes::router,
tools::{EraTools, GenesisTools, GenesisToolsDependency},
Expand Down Expand Up @@ -90,9 +90,9 @@ fn setup_genesis_dependencies(
)?),
config.store_retention_limit,
));
let stake_store = Arc::new(StakePoolRepository::new(Arc::new(Mutex::new(
Connection::open(sqlite_db_path.clone().unwrap())?,
))));
let stake_store = Arc::new(StakePoolStore::new(Arc::new(Mutex::new(Connection::open(
sqlite_db_path.clone().unwrap(),
)?))));
let single_signature_store = Arc::new(SingleSignatureStore::new(
Box::new(SQLiteAdapter::new("single_signature", sqlite_db_path)?),
config.store_retention_limit,
Expand Down Expand Up @@ -359,9 +359,9 @@ impl ServeCommand {
)?),
config.store_retention_limit,
));
let stake_store = Arc::new(StakePoolRepository::new(Arc::new(Mutex::new(
Connection::open(sqlite_db_path.clone().unwrap())?,
))));
let stake_store = Arc::new(StakePoolStore::new(Arc::new(Mutex::new(Connection::open(
sqlite_db_path.clone().unwrap(),
)?))));
let single_signature_store = Arc::new(SingleSignatureStore::new(
Box::new(SQLiteAdapter::new(
"single_signature",
Expand Down
187 changes: 181 additions & 6 deletions mithril-aggregator/src/database/provider/stake_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ use mithril_common::{

use mithril_common::StdError;

/// Delete stake pools for Epoch older than this.
const STAKE_POOL_PRUNE_EPOCH_THRESHOLD: Epoch = Epoch(2);

/// Stake pool as read from Chain.
/// TODO remove this compile directive ↓
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
pub struct StakePool {
/// Pool Id
stake_pool_id: PartyId,
Expand Down Expand Up @@ -119,7 +121,7 @@ impl<'client> Provider<'client> for StakePoolProvider<'client> {
let aliases = SourceAlias::new(&[("{:stake_pool:}", "sp")]);
let projection = Self::Entity::get_projection().expand(aliases);

format!("select {projection} from stake_pool as sp where {condition}")
format!("select {projection} from stake_pool as sp where {condition} order by epoch asc, stake desc")
}
}

Expand Down Expand Up @@ -186,20 +188,62 @@ impl<'conn> Provider<'conn> for UpdateStakePoolProvider<'conn> {
}
}

/// Provider to remove old data from the stake_pool table
pub struct DeleteStakePoolProvider<'conn> {
connection: &'conn Connection,
}

impl<'conn> Provider<'conn> for DeleteStakePoolProvider<'conn> {
type Entity = StakePool;

fn get_connection(&'conn self) -> &'conn Connection {
self.connection
}

fn get_definition(&self, condition: &str) -> String {
// it is important to alias the fields with the same name as the table
// since the table cannot be aliased in a RETURNING statement in SQLite.
let projection = Self::Entity::get_projection()
.expand(SourceAlias::new(&[("{:stake_pool:}", "stake_pool")]));

format!("delete from stake_pool where {condition} returning {projection}")
}
}

impl<'conn> DeleteStakePoolProvider<'conn> {
/// Create a new instance
pub fn new(connection: &'conn Connection) -> Self {
Self { connection }
}

/// Create the SQL condition to prune data older than the given Epoch.
fn get_prune_condition(&self, epoch_threshold: Epoch) -> WhereCondition {
let epoch_value = Value::Integer(i64::try_from(epoch_threshold.0).unwrap());

WhereCondition::new("epoch < ?*", vec![epoch_value])
}

/// Prune the stake pools data older than the given epoch.
pub fn prune(&self, epoch_threshold: Epoch) -> Result<EntityCursor<StakePool>, StdError> {
let filters = self.get_prune_condition(epoch_threshold);

self.find(filters)
}
}
/// Service to deal with stake pools (read & write).
pub struct StakePoolRepository {
pub struct StakePoolStore {
connection: Arc<Mutex<Connection>>,
}

impl StakePoolRepository {
impl StakePoolStore {
/// Create a new StakePool service
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
Self { connection }
}
}

#[async_trait]
impl StakeStorer for StakePoolRepository {
impl StakeStorer for StakePoolStore {
async fn save_stakes(
&self,
epoch: Epoch,
Expand All @@ -221,6 +265,12 @@ impl StakeStorer for StakePoolRepository {
.map_err(|e| AdapterError::GeneralError(format!("{e}")))?;
new_stakes.insert(pool_id.to_string(), stake_pool.stake);
}
// Clean useless old stake distributions if needed.
if epoch > STAKE_POOL_PRUNE_EPOCH_THRESHOLD {
let _ = DeleteStakePoolProvider::new(connection)
.prune(epoch - STAKE_POOL_PRUNE_EPOCH_THRESHOLD)
.map_err(AdapterError::InitializationError)?;
}
connection
.execute("commit transaction")
.map_err(|e| AdapterError::QueryError(e.into()))?;
Expand Down Expand Up @@ -249,6 +299,8 @@ impl StakeStorer for StakePoolRepository {

#[cfg(test)]
mod tests {
use crate::database::migration::get_migrations;

use super::*;

#[test]
Expand Down Expand Up @@ -293,4 +345,127 @@ mod tests {
params
);
}

#[test]
fn prune() {
let connection = Connection::open(":memory:").unwrap();
let provider = DeleteStakePoolProvider::new(&connection);
let condition = provider.get_prune_condition(Epoch(5));
let (condition, params) = condition.expand();

assert_eq!("epoch < ?1".to_string(), condition);
assert_eq!(vec![Value::Integer(5)], params);
}

fn setup_db(connection: &Connection) -> Result<(), StdError> {
let migrations = get_migrations();
let migration =
migrations
.iter()
.find(|&m| m.version == 1)
.ok_or_else(|| -> StdError {
"There should be a migration version 1".to_string().into()
})?;
let query = {
// leverage the expanded parameter from this provider which is unit
// tested on its own above.
let update_provider = UpdateStakePoolProvider::new(connection);
let (sql_values, _) = update_provider
.get_update_condition("pool_id", Epoch(1), 1000)
.expand();

connection.execute(&migration.alterations)?;

format!("insert into stake_pool {sql_values}")
};
let stake_distribution: &[(&str, i64, i64); 9] = &[
("pool1", 1, 1000),
("pool2", 1, 1100),
("pool3", 1, 1300),
("pool1", 2, 1230),
("pool2", 2, 1090),
("pool3", 2, 1300),
("pool1", 3, 1250),
("pool2", 3, 1370),
("pool3", 3, 1300),
];
for (pool_id, epoch, stake) in stake_distribution {
let mut statement = connection.prepare(&query)?;

statement.bind(1, *pool_id).unwrap();
statement.bind(2, *epoch).unwrap();
statement.bind(3, *stake).unwrap();
statement.next().unwrap();
}

Ok(())
}

#[test]
fn test_get_stake_pools() {
let connection = Connection::open(":memory:").unwrap();
setup_db(&connection).unwrap();

let provider = StakePoolProvider::new(&connection);
let mut cursor = provider.get_by_epoch(&Epoch(1)).unwrap();

let stake_pool = cursor.next().expect("Should have a stake pool 'pool1'.");
assert_eq!("pool3".to_string(), stake_pool.stake_pool_id);
assert_eq!(Epoch(1), stake_pool.epoch);
assert_eq!(1300, stake_pool.stake);
assert_eq!(2, cursor.count());

let mut cursor = provider.get_by_epoch(&Epoch(3)).unwrap();

let stake_pool = cursor.next().expect("Should have a stake pool 'pool2'.");
assert_eq!("pool2".to_string(), stake_pool.stake_pool_id);
assert_eq!(Epoch(3), stake_pool.epoch);
assert_eq!(1370, stake_pool.stake);
assert_eq!(2, cursor.count());

let cursor = provider.get_by_epoch(&Epoch(5)).unwrap();
assert_eq!(0, cursor.count());
}

#[test]
fn test_update_stakes() {
let connection = Connection::open(":memory:").unwrap();
setup_db(&connection).unwrap();

let provider = UpdateStakePoolProvider::new(&connection);
let stake_pool = provider.persist("pool4", Epoch(3), 9999).unwrap();

assert_eq!("pool4".to_string(), stake_pool.stake_pool_id);
assert_eq!(Epoch(3), stake_pool.epoch);
assert_eq!(9999, stake_pool.stake);

let provider = StakePoolProvider::new(&connection);
let mut cursor = provider.get_by_epoch(&Epoch(3)).unwrap();
let stake_pool = cursor.next().expect("Should have a stake pool 'pool4'.");

assert_eq!("pool4".to_string(), stake_pool.stake_pool_id);
assert_eq!(Epoch(3), stake_pool.epoch);
assert_eq!(9999, stake_pool.stake);
assert_eq!(3, cursor.count());
}

#[test]
fn test_prune() {
let connection = Connection::open(":memory:").unwrap();
setup_db(&connection).unwrap();

let provider = DeleteStakePoolProvider::new(&connection);
let cursor = provider.prune(Epoch(2)).unwrap();

assert_eq!(3, cursor.count());

let provider = StakePoolProvider::new(&connection);
let cursor = provider.get_by_epoch(&Epoch(1)).unwrap();

assert_eq!(0, cursor.count());

let cursor = provider.get_by_epoch(&Epoch(2)).unwrap();

assert_eq!(3, cursor.count());
}
}
11 changes: 9 additions & 2 deletions mithril-common/src/entities/epoch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ impl Sub for Epoch {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 - rhs.0)
Self(self.0.saturating_sub(rhs.0))
}
}

impl Sub<u64> for Epoch {
type Output = Self;

fn sub(self, rhs: u64) -> Self::Output {
Self(self.0 - rhs)
Self(self.0.saturating_sub(rhs))
}
}

Expand Down Expand Up @@ -181,9 +181,16 @@ mod tests {
assert_eq!(Epoch(8), epoch);
}

#[test]
fn saturating_sub() {
assert_eq!(Epoch(0), Epoch(1) - Epoch(5));
assert_eq!(Epoch(0), Epoch(1) - 5_u64);
}

#[test]
fn test_previous() {
assert_eq!(Epoch(2), Epoch(3).previous().unwrap());
assert!(Epoch(0).previous().is_err());
}

#[test]
Expand Down
11 changes: 5 additions & 6 deletions mithril-common/src/sqlite/source_alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ mod tests {
#[test]
fn simple_source_alias() {
let source_alias = SourceAlias::new(&[("first", "one"), ("second", "two")]);
let target = source_alias
.get_iterator()
.map(|(name, alias)| format!("{name} => {alias}"))
.collect::<Vec<String>>()
.join(", ");
let mut fields = "first.one, second.two".to_string();

assert_eq!("first => one, second => two".to_string(), target);
for (alias, source) in source_alias.get_iterator() {
fields = fields.replace(alias, source);
}
assert_eq!("one.one, two.two".to_string(), fields);
}
}

0 comments on commit a35fc91

Please sign in to comment.