Skip to content

Commit

Permalink
Merge pull request #233 from semiotic-ai/gusinacio/check-error-recove…
Browse files Browse the repository at this point in the history
…rable

feat!: add retryable errors to checks
  • Loading branch information
gusinacio authored Aug 19, 2024
2 parents 9e1915b + 51f04cb commit ff856d9
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 32 deletions.
32 changes: 21 additions & 11 deletions tap_core/src/manager/context/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl EscrowHandler for InMemoryContext {
pub mod checks {
use crate::{
receipt::{
checks::{Check, CheckResult, ReceiptCheck},
checks::{Check, CheckError, CheckResult, ReceiptCheck},
state::Checking,
ReceiptError, ReceiptWithState,
},
Expand Down Expand Up @@ -306,10 +306,12 @@ pub mod checks {
{
Ok(())
} else {
Err(ReceiptError::InvalidAllocationID {
received_allocation_id,
}
.into())
Err(CheckError::Failed(
ReceiptError::InvalidAllocationID {
received_allocation_id,
}
.into(),
))
}
}
}
Expand All @@ -325,14 +327,22 @@ pub mod checks {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
.map_err(|e| ReceiptError::InvalidSignature {
source_error_message: e.to_string(),
.map_err(|e| {
CheckError::Failed(
ReceiptError::InvalidSignature {
source_error_message: e.to_string(),
}
.into(),
)
})?;

if !self.valid_signers.contains(&recovered_address) {
Err(ReceiptError::InvalidSignature {
source_error_message: "Invalid signer".to_string(),
}
.into())
Err(CheckError::Failed(
ReceiptError::InvalidSignature {
source_error_message: "Invalid signer".to_string(),
}
.into(),
))
} else {
Ok(())
}
Expand Down
7 changes: 5 additions & 2 deletions tap_core/src/manager/tap_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
receipt::{
checks::{CheckBatch, CheckList, TimestampCheck, UniqueCheck},
state::{Failed, Reserved},
ReceiptWithState, SignedReceipt,
ReceiptError, ReceiptWithState, SignedReceipt,
},
Error,
};
Expand Down Expand Up @@ -139,7 +139,10 @@ where
failed_receipts.extend(already_failed);

for receipt in checking_receipts.into_iter() {
let receipt = receipt.finalize_receipt_checks(&self.checks).await;
let receipt = receipt
.finalize_receipt_checks(&self.checks)
.await
.map_err(|e| Error::ReceiptError(ReceiptError::RetryableCheck(e)))?;

match receipt {
Ok(checked) => awaiting_reserve_receipts.push(checked),
Expand Down
22 changes: 16 additions & 6 deletions tap_core/src/receipt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ use std::{
pub type ReceiptCheck = Arc<dyn Check + Sync + Send>;

/// Result of a check operation. It uses the `anyhow` crate to handle errors.
pub type CheckResult = anyhow::Result<()>;
pub type CheckResult = Result<(), CheckError>;

#[derive(thiserror::Error, Debug)]
pub enum CheckError {
#[error(transparent)]
Retryable(anyhow::Error),
#[error(transparent)]
Failed(anyhow::Error),
}

/// CheckList is a NewType pattern to store a list of checks.
/// It is a wrapper around an Arc of ReceiptCheck[].
Expand Down Expand Up @@ -115,11 +123,13 @@ impl Check for StatefulTimestampCheck {
let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap();
let signed_receipt = receipt.signed_receipt();
if signed_receipt.message.timestamp_ns <= min_timestamp_ns {
return Err(ReceiptError::InvalidTimestamp {
received_timestamp: signed_receipt.message.timestamp_ns,
timestamp_min: min_timestamp_ns,
}
.into());
return Err(CheckError::Failed(
ReceiptError::InvalidTimestamp {
received_timestamp: signed_receipt.message.timestamp_ns,
timestamp_min: min_timestamp_ns,
}
.into(),
));
}
Ok(())
}
Expand Down
4 changes: 3 additions & 1 deletion tap_core/src/receipt/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ pub enum ReceiptError {
#[error("Attempt to collect escrow failed")]
SubtractEscrowFailed,
#[error("Issue encountered while performing check: {0}")]
CheckFailedToComplete(String),
CheckFailure(String),
#[error("Retryable check error encountered: {0}")]
RetryableCheck(String),
}
20 changes: 11 additions & 9 deletions tap_core/src/receipt/received_receipt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use alloy::dyn_abi::Eip712Domain;

use super::checks::CheckError;
use super::{Receipt, ReceiptError, ReceiptResult, SignedReceipt};
use crate::receipt::state::{AwaitingReserve, Checking, Failed, ReceiptState, Reserved};
use crate::{
Expand Down Expand Up @@ -92,10 +93,10 @@ impl ReceiptWithState<Checking> {
pub async fn perform_checks(&mut self, checks: &[ReceiptCheck]) -> ReceiptResult<()> {
for check in checks {
// return early on an error
check
.check(self)
.await
.map_err(|e| ReceiptError::CheckFailedToComplete(e.to_string()))?;
check.check(self).await.map_err(|e| match e {
CheckError::Retryable(e) => ReceiptError::RetryableCheck(e.to_string()),
CheckError::Failed(e) => ReceiptError::CheckFailure(e.to_string()),
})?;
}
Ok(())
}
Expand All @@ -108,14 +109,15 @@ impl ReceiptWithState<Checking> {
pub async fn finalize_receipt_checks(
mut self,
checks: &[ReceiptCheck],
) -> ResultReceipt<AwaitingReserve> {
) -> Result<ResultReceipt<AwaitingReserve>, String> {
let all_checks_passed = self.perform_checks(checks).await;

if let Err(e) = all_checks_passed {
Err(self.perform_state_error(e))
if let Err(ReceiptError::RetryableCheck(e)) = all_checks_passed {
Err(e.to_string())
} else if let Err(e) = all_checks_passed {
Ok(Err(self.perform_state_error(e)))
} else {
let checked = self.perform_state_changes(AwaitingReserve);
Ok(checked)
Ok(Ok(checked))
}
}
}
Expand Down
84 changes: 81 additions & 3 deletions tap_core/tests/manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
use std::{
collections::HashMap,
str::FromStr,
sync::{Arc, RwLock},
sync::{atomic::AtomicBool, Arc, RwLock},
time::{SystemTime, UNIX_EPOCH},
};

use alloy::{dyn_abi::Eip712Domain, primitives::Address, signers::local::PrivateKeySigner};
use anyhow::anyhow;
use rstest::*;

fn get_current_timestamp_u64_ns() -> anyhow::Result<u64> {
Expand All @@ -24,8 +25,9 @@ use tap_core::{
},
rav::ReceiptAggregateVoucher,
receipt::{
checks::{CheckList, StatefulTimestampCheck},
Receipt,
checks::{Check, CheckError, CheckList, StatefulTimestampCheck},
state::Checking,
Receipt, ReceiptWithState,
},
signed_message::EIP712SignedMessage,
tap_eip712_domain,
Expand Down Expand Up @@ -530,3 +532,79 @@ async fn manager_create_rav_and_ignore_invalid_receipts(
//Rav Value corresponds only to value of one receipt
assert_eq!(expected_rav.valueAggregate, 20);
}

#[rstest]
#[tokio::test]
async fn test_retryable_checks(
allocation_ids: Vec<Address>,
domain_separator: Eip712Domain,
context: ContextFixture,
) {
struct RetryableCheck(Arc<AtomicBool>);

#[async_trait::async_trait]
impl Check for RetryableCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> Result<(), CheckError> {
// we want to fail only if nonce is 5 and if is create rav step
if self.0.load(std::sync::atomic::Ordering::SeqCst)
&& receipt.signed_receipt().message.nonce == 5
{
Err(CheckError::Retryable(anyhow!("Retryable error")))
} else {
Ok(())
}
}
}

let ContextFixture {
context,
checks,
escrow_storage,
signer,
..
} = context;

let is_create_rav = Arc::new(AtomicBool::new(false));

let mut checks: Vec<Arc<dyn Check + Send + Sync>> = checks.iter().cloned().collect();
checks.push(Arc::new(RetryableCheck(is_create_rav.clone())));

let manager = Manager::new(
domain_separator.clone(),
context.clone(),
CheckList::new(checks),
);

escrow_storage
.write()
.unwrap()
.insert(signer.address(), 999999);

let mut stored_signed_receipts = Vec::new();
for i in 0..10 {
let receipt = Receipt {
allocation_id: allocation_ids[0],
timestamp_ns: i + 1,
nonce: i,
value: 20u128,
};
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
stored_signed_receipts.push(signed_receipt.clone());
manager
.verify_and_store_receipt(signed_receipt)
.await
.unwrap();
}

is_create_rav.store(true, std::sync::atomic::Ordering::SeqCst);

let rav_request = manager.create_rav_request(0, None).await;

assert_eq!(
rav_request.expect_err("Didn't fail").to_string(),
tap_core::Error::ReceiptError(tap_core::receipt::ReceiptError::RetryableCheck(
"Retryable error".to_string()
))
.to_string()
);
}
2 changes: 2 additions & 0 deletions tap_core/tests/received_receipt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ async fn partial_then_finalize_valid_receipt(

let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap();
let receipt = awaiting_escrow_receipt
.unwrap()
.check_and_reserve_escrow(&context, &domain_separator)
.await;
assert!(receipt.is_ok());
Expand Down Expand Up @@ -234,6 +235,7 @@ async fn standard_lifetime_valid_receipt(

let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap();
let receipt = awaiting_escrow_receipt
.unwrap()
.check_and_reserve_escrow(&context, &domain_separator)
.await;
assert!(receipt.is_ok());
Expand Down

0 comments on commit ff856d9

Please sign in to comment.