Skip to content

Commit

Permalink
persist counter state
Browse files Browse the repository at this point in the history
  • Loading branch information
sinui0 committed Apr 27, 2024
1 parent 9fb7fe1 commit 4038fba
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mpz-core/src/prg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Implement AES-based PRG.

use std::collections::HashMap;

use crate::{aes::AesEncryptor, Block};
use rand::Rng;
use rand_core::{
Expand All @@ -11,6 +13,8 @@ use rand_core::{
#[derive(Clone)]
struct PrgCore {
aes: AesEncryptor,
// Stores the counter for each stream id.
state: HashMap<u64, u64>,
stream_id: u64,
counter: u64,
}
Expand Down Expand Up @@ -48,6 +52,7 @@ impl SeedableRng for PrgCore {
let aes = AesEncryptor::new(seed);
Self {
aes,
state: Default::default(),
stream_id: 0u64,
counter: 0u64,
}
Expand Down Expand Up @@ -123,7 +128,13 @@ impl Prg {

/// Sets the stream id.
pub fn set_stream_id(&mut self, stream_id: u64) {
let state = &mut self.0.core.state;
state.insert(self.0.core.stream_id, self.0.core.counter);

let counter = state.get(&stream_id).copied().unwrap_or(0);

self.0.core.stream_id = stream_id;
self.0.core.counter = counter;
}

/// Generate a random bool value.
Expand Down Expand Up @@ -195,4 +206,19 @@ mod tests {

assert_ne!(x[0], y[0]);
}

#[test]
fn test_prg_state_persisted() {
let mut prg = Prg::from_seed(Block::ZERO);
let mut x = vec![Block::ZERO; 2];
prg.random_blocks(&mut x);

let counter = prg.counter();
assert_ne!(counter, 0);

prg.set_stream_id(1);
prg.set_stream_id(0);

assert_eq!(prg.counter(), counter);
}
}

0 comments on commit 4038fba

Please sign in to comment.