Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add context to checks #241

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tap_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ anyhow.workspace = true
rand.workspace = true
thiserror = "1.0.38"
async-trait = "0.1.72"
anymap3 = "1.0.0"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_std"] }
Expand Down
6 changes: 3 additions & 3 deletions tap_core/src/manager/context/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ pub mod checks {
receipt::{
checks::{Check, CheckError, CheckResult, ReceiptCheck},
state::Checking,
ReceiptError, ReceiptWithState,
Context, ReceiptError, ReceiptWithState,
},
signed_message::MessageId,
};
Expand Down Expand Up @@ -296,7 +296,7 @@ pub mod checks {

#[async_trait::async_trait]
impl Check for AllocationIdCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let received_allocation_id = receipt.signed_receipt().message.allocation_id;
if self
.allocation_ids
Expand All @@ -323,7 +323,7 @@ pub mod checks {

#[async_trait::async_trait]
impl Check for SignatureCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
Expand Down
5 changes: 3 additions & 2 deletions tap_core/src/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
//! ReceiptWithState,
//! state::Checking,
//! checks::CheckList,
//! ReceiptError
//! ReceiptError,
//! Context
//! },
//! manager::{
//! Manager,
Expand Down Expand Up @@ -70,7 +71,7 @@
//! let receipt = EIP712SignedMessage::new(&domain_separator, message, &wallet).unwrap();
//!
//! let manager = Manager::new(domain_separator, MyContext, CheckList::empty());
//! manager.verify_and_store_receipt(receipt).await.unwrap()
//! manager.verify_and_store_receipt(&Context::new(), receipt).await.unwrap()
//! # }
//! ```
//!
Expand Down
11 changes: 7 additions & 4 deletions tap_core/src/manager/tap_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
receipt::{
checks::{CheckBatch, CheckList, TimestampCheck, UniqueCheck},
state::{Failed, Reserved},
ReceiptError, ReceiptWithState, SignedReceipt,
Context, ReceiptError, ReceiptWithState, SignedReceipt,
},
Error,
};
Expand Down Expand Up @@ -99,6 +99,7 @@ where
{
async fn collect_receipts(
&self,
ctx: &Context,
timestamp_buffer_ns: u64,
min_timestamp_ns: u64,
limit: Option<u64>,
Expand Down Expand Up @@ -140,7 +141,7 @@ where

for receipt in checking_receipts.into_iter() {
let receipt = receipt
.finalize_receipt_checks(&self.checks)
.finalize_receipt_checks(ctx, &self.checks)
.await
.map_err(|e| Error::ReceiptError(ReceiptError::RetryableCheck(e)))?;

Expand Down Expand Up @@ -184,6 +185,7 @@ where
///
pub async fn create_rav_request(
&self,
ctx: &Context,
timestamp_buffer_ns: u64,
receipts_limit: Option<u64>,
) -> Result<RAVRequest, Error> {
Expand All @@ -194,7 +196,7 @@ where
.unwrap_or(0);

let (valid_receipts, invalid_receipts) = self
.collect_receipts(timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
.collect_receipts(ctx, timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
.await?;

let expected_rav = Self::generate_expected_rav(&valid_receipts, previous_rav.clone());
Expand Down Expand Up @@ -271,12 +273,13 @@ where
///
pub async fn verify_and_store_receipt(
&self,
ctx: &Context,
signed_receipt: SignedReceipt,
) -> std::result::Result<(), Error> {
let mut received_receipt = ReceiptWithState::new(signed_receipt);

// perform checks
received_receipt.perform_checks(&self.checks).await?;
received_receipt.perform_checks(ctx, &self.checks).await?;

// store the receipt
self.context
Expand Down
2 changes: 1 addition & 1 deletion tap_core/src/rav.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//! 1. Create a [`RAVRequest`] with the valid receipts and the previous RAV.
//! 2. Send the request to the aggregator.
//! 3. The aggregator will verify the request and increment the total amount that
//! has been aggregated.
//! has been aggregated.
//! 4. The aggregator will return a [`SignedRAV`].
//! 5. Store the [`SignedRAV`].
//! 6. Repeat the process until the allocation is closed.
Expand Down
10 changes: 5 additions & 5 deletions tap_core/src/receipt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
//! # use std::sync::Arc;
//! use tap_core::{
//! receipt::checks::{Check, CheckResult, ReceiptCheck},
//! receipt::{ReceiptWithState, state::Checking}
//! receipt::{Context, ReceiptWithState, state::Checking}
//! };
//! # use async_trait::async_trait;
//!
//! struct MyCheck;
//!
//! #[async_trait]
//! impl Check for MyCheck {
//! async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
//! async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
//! // Implement your check here
//! Ok(())
//! }
Expand All @@ -33,7 +33,7 @@ use crate::signed_message::{SignatureBytes, SignatureBytesExt};

use super::{
state::{Checking, Failed},
ReceiptError, ReceiptWithState,
Context, ReceiptError, ReceiptWithState,
};
use std::{
collections::HashSet,
Expand Down Expand Up @@ -80,7 +80,7 @@ impl Deref for CheckList {
/// Check trait is implemented by the lib user to validate receipts before they are stored.
#[async_trait::async_trait]
pub trait Check {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult;
async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult;
}

/// CheckBatch is mostly used by the lib to implement checks
Expand Down Expand Up @@ -119,7 +119,7 @@ impl StatefulTimestampCheck {

#[async_trait::async_trait]
impl Check for StatefulTimestampCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap();
let signed_receipt = receipt.signed_receipt();
if signed_receipt.message.timestamp_ns <= min_timestamp_ns {
Expand Down
2 changes: 2 additions & 0 deletions tap_core/src/receipt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@ pub type SignedReceipt = EIP712SignedMessage<Receipt>;

/// Result type for receipt
pub type ReceiptResult<T> = Result<T, ReceiptError>;

pub type Context = anymap3::Map<dyn std::any::Any + Send + Sync>;
24 changes: 14 additions & 10 deletions tap_core/src/receipt/received_receipt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
use alloy::dyn_abi::Eip712Domain;

use super::checks::CheckError;
use super::{Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use super::{Context, Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use crate::receipt::state::{AwaitingReserve, Checking, Failed, ReceiptState, Reserved};
use crate::{
manager::adapters::EscrowHandler, receipt::checks::ReceiptCheck,
Expand All @@ -28,16 +28,15 @@ pub type ResultReceipt<S> = std::result::Result<ReceiptWithState<S>, ReceiptWith
/// Typestate pattern for tracking the state of a receipt
///
/// - The [ `ReceiptState` ] trait represents the different states a receipt
/// can be in.
/// can be in.
/// - The [ `Checking` ] state is used to represent a receipt that is currently
/// being checked.
/// being checked.
/// - The [ `Failed` ] state is used to represent a receipt that has failed a
/// check or validation.
/// check or validation.
/// - The [ `AwaitingReserve` ] state is used to represent a receipt that has
/// passed all checks and is
/// awaiting escrow reservation.
/// passed all checks and is awaiting escrow reservation.
/// - The [ `Reserved` ] state is used to represent a receipt that has
/// successfully reserved escrow.
/// successfully reserved escrow.
#[derive(Debug, Clone)]
pub struct ReceiptWithState<S>
where
Expand Down Expand Up @@ -90,10 +89,14 @@ impl ReceiptWithState<Checking> {
/// cannot be comleted in the receipts current internal state.
/// All other checks must be complete before `CheckAndReserveEscrow`.
///
pub async fn perform_checks(&mut self, checks: &[ReceiptCheck]) -> ReceiptResult<()> {
pub async fn perform_checks(
&mut self,
ctx: &Context,
checks: &[ReceiptCheck],
) -> ReceiptResult<()> {
for check in checks {
// return early on an error
check.check(self).await.map_err(|e| match e {
check.check(ctx, self).await.map_err(|e| match e {
CheckError::Retryable(e) => ReceiptError::RetryableCheck(e.to_string()),
CheckError::Failed(e) => ReceiptError::CheckFailure(e.to_string()),
})?;
Expand All @@ -108,9 +111,10 @@ impl ReceiptWithState<Checking> {
///
pub async fn finalize_receipt_checks(
mut self,
ctx: &Context,
checks: &[ReceiptCheck],
) -> Result<ResultReceipt<AwaitingReserve>, String> {
let all_checks_passed = self.perform_checks(checks).await;
let all_checks_passed = self.perform_checks(ctx, checks).await;
if let Err(ReceiptError::RetryableCheck(e)) = all_checks_passed {
Err(e.to_string())
} else if let Err(e) = all_checks_passed {
Expand Down
41 changes: 24 additions & 17 deletions tap_core/tests/manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tap_core::{
receipt::{
checks::{Check, CheckError, CheckList, StatefulTimestampCheck},
state::Checking,
Receipt, ReceiptWithState,
Context, Receipt, ReceiptWithState,
},
signed_message::EIP712SignedMessage,
tap_eip712_domain,
Expand Down Expand Up @@ -145,7 +145,7 @@ async fn manager_verify_and_store_varying_initial_checks(
.insert(signer.address(), 999999);

assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
}
Expand Down Expand Up @@ -184,11 +184,11 @@ async fn manager_create_rav_request_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -279,12 +279,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -323,12 +323,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
}
let rav_request_result = manager.create_rav_request(0, None).await;
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_result.is_ok());

let rav_request = rav_request_result.unwrap();
Expand Down Expand Up @@ -391,7 +391,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
Expand All @@ -403,7 +403,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
manager.remove_obsolete_receipts().await.unwrap();
}

let rav_request_1_result = manager.create_rav_request(0, None).await;
let rav_request_1_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_1_result.is_ok());

let rav_request_1 = rav_request_1_result.unwrap();
Expand Down Expand Up @@ -438,7 +438,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
stored_signed_receipts.push(signed_receipt.clone());
query_appraisals.write().unwrap().insert(query_id, value);
assert!(manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.is_ok());
expected_accumulated_value += value;
Expand All @@ -458,7 +458,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
);
}

let rav_request_2_result = manager.create_rav_request(0, None).await;
let rav_request_2_result = manager.create_rav_request(&Context::new(), 0, None).await;
assert!(rav_request_2_result.is_ok());

let rav_request_2 = rav_request_2_result.unwrap();
Expand Down Expand Up @@ -518,12 +518,15 @@ async fn manager_create_rav_and_ignore_invalid_receipts(
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
stored_signed_receipts.push(signed_receipt.clone());
manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.unwrap();
}

let rav_request = manager.create_rav_request(0, None).await.unwrap();
let rav_request = manager
.create_rav_request(&Context::new(), 0, None)
.await
.unwrap();
let expected_rav = rav_request.expected_rav.unwrap();

assert_eq!(rav_request.valid_receipts.len(), 1);
Expand All @@ -544,7 +547,11 @@ async fn test_retryable_checks(

#[async_trait::async_trait]
impl Check for RetryableCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> Result<(), CheckError> {
async fn check(
&self,
_: &Context,
receipt: &ReceiptWithState<Checking>,
) -> Result<(), CheckError> {
// we want to fail only if nonce is 5 and if is create rav step
if self.0.load(std::sync::atomic::Ordering::SeqCst)
&& receipt.signed_receipt().message.nonce == 5
Expand Down Expand Up @@ -591,14 +598,14 @@ async fn test_retryable_checks(
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
stored_signed_receipts.push(signed_receipt.clone());
manager
.verify_and_store_receipt(signed_receipt)
.verify_and_store_receipt(&Context::new(), signed_receipt)
.await
.unwrap();
}

is_create_rav.store(true, std::sync::atomic::Ordering::SeqCst);

let rav_request = manager.create_rav_request(0, None).await;
let rav_request = manager.create_rav_request(&Context::new(), 0, None).await;

assert_eq!(
rav_request.expect_err("Didn't fail").to_string(),
Expand Down
Loading
Loading