From 3885322844df02cffd4bc7493a6d8409bb1e815f Mon Sep 17 00:00:00 2001 From: ndefokou Date: Wed, 18 Dec 2024 16:34:31 +0100 Subject: [PATCH] fixe()integrating the circuit breaker on the different protocols --- .../protocols/forward/Cargo.toml | 1 + .../protocols/forward/src/handler.rs | 130 +-- .../protocols/forward/src/plugin.rs | 13 +- .../mediator-coordination/src/errors.rs | 3 + .../src/handler/stateful.rs | 774 ++++++++++-------- .../mediator-coordination/src/plugin.rs | 26 +- .../protocols/pickup/Cargo.toml | 1 + .../protocols/pickup/src/handler.rs | 635 ++++++++------ .../protocols/pickup/src/plugin.rs | 26 +- .../shared/src/CircuitBreaker.rs | 75 -- .../shared/src/circuit_breaker.rs | 90 ++ .../didcomm-messaging/shared/src/lib.rs | 2 +- .../didcomm-messaging/shared/src/state.rs | 4 +- 13 files changed, 1056 insertions(+), 724 deletions(-) delete mode 100644 crates/web-plugins/didcomm-messaging/shared/src/CircuitBreaker.rs create mode 100644 crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml b/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml index 85278439..043bf1aa 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/Cargo.toml @@ -17,6 +17,7 @@ thiserror.workspace = true didcomm = { workspace = true, features = ["uniffi"] } hyper = { workspace = true, features = ["full"] } axum = { workspace = true, features = ["macros"] } +tokio = "1.27.0" [dev-dependencies] keystore = { workspace = true, features = ["test-utils"] } diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs b/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs index 0b260527..7cb1aafd 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/src/handler.rs @@ -4,78 +4,84 @@ use didcomm::{AttachmentData, Message}; use mongodb::bson::doc; use serde_json::{json, Value}; use shared::{ + circuit_breaker::CircuitBreaker, repository::entity::{Connection, RoutedMessage}, retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, }; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Mutex; /// Mediator receives forwarded messages, extract the next field in the message body, and the attachments in the message /// then stores the attachment with the next field as key for pickup pub(crate) async fn mediator_forward_process( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, ForwardError> { - let AppStateRepository { - message_repository, - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or_else(|| ForwardError::InternalServerError)?; - - let circuit_breaker = state.circuit_breaker.clone(); - if circuit_breaker.is_open() { - return Err(ForwardError::CircuitOpen); - } - - let next = match checks(&message, connection_repository).await.ok() { - Some(next) => Ok(next), - None => Err(ForwardError::InternalServerError), - }; - - let attachments = message.attachments.unwrap_or_default(); - for attachment in attachments { - let attached = match attachment.data { - AttachmentData::Json { value: data } => data.json, - AttachmentData::Base64 { value: data } => json!(data.base64), - AttachmentData::Links { value: data } => json!(data.links), - }; - - let result = retry_async( - || { - let attached = attached.clone(); - let recipient_did = next.as_ref().unwrap().to_owned(); - - async move { - message_repository - .store(RoutedMessage { - id: None, - message: attached, - recipient_did, - }) - .await + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = Arc::clone(&state); + let message = message.clone(); + async move { + let AppStateRepository { + message_repository, + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or_else(|| ForwardError::InternalServerError)?; + + let next = match checks(&message, connection_repository).await.ok() { + Some(next) => Ok(next), + None => Err(ForwardError::InternalServerError), + }?; + + let attachments = message.attachments.unwrap_or_default(); + for attachment in attachments { + let attached = match attachment.data { + AttachmentData::Json { value: data } => data.json, + AttachmentData::Base64 { value: data } => json!(data.base64), + AttachmentData::Links { value: data } => json!(data.links), + }; + retry_async( + || { + let attached = attached.clone(); + let recipient_did = next.to_owned(); + + async move { + message_repository + .store(RoutedMessage { + id: None, + message: attached, + recipient_did, + }) + .await + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(1)), + ) + .await + .map_err(|_| ForwardError::InternalServerError)?; } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(1)), - ) + Ok::, ForwardError>(None) + } + }) .await; - match result { - Ok(_) => circuit_breaker.record_success(), - Err(_) => { - circuit_breaker.record_failure(); - return Err(ForwardError::InternalServerError); - } - }; + match result { + Some(Ok(None)) => Ok(None), + Some(Ok(Some(_))) => Err(ForwardError::InternalServerError), + Some(Err(err)) => Err(err), + None => Err(ForwardError::CircuitOpen), } - - Ok(None) } async fn checks( @@ -113,6 +119,7 @@ mod test { use keystore::Secrets; use serde_json::json; use shared::{ + circuit_breaker, repository::{ entity::Connection, tests::{MockConnectionRepository, MockMessagesRepository}, @@ -196,9 +203,16 @@ mod test { .await .expect("Unable unpack"); - let msg = mediator_forward_process(Arc::new(state.clone()), msg) - .await - .unwrap(); + // Wrap the CircuitBreaker in Arc and Mutex + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let msg: Option = mediator_forward_process( + Arc::new(state.clone()), + msg, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap(); println!("Mediator1 is forwarding message \n{:?}\n", msg); } diff --git a/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs index 75744505..86325bc2 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/forward/src/plugin.rs @@ -3,8 +3,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; pub struct RoutingProtocol; @@ -17,7 +18,13 @@ impl MessageHandler for ForwardHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::mediator_forward_process(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + // Pass the state, msg, and the circuit_breaker as arguments + crate::handler::mediator_forward_process(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs index 7e01a08f..2a15b917 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/errors.rs @@ -15,6 +15,8 @@ pub(crate) enum MediationError { UnexpectedMessageFormat, #[error("internal server error")] InternalServerError, + #[error("service unavailable")] + CircuitOpen, } impl IntoResponse for MediationError { @@ -26,6 +28,7 @@ impl IntoResponse for MediationError { MediationError::UncoordinatedSender => StatusCode::UNAUTHORIZED, MediationError::UnexpectedMessageFormat => StatusCode::BAD_REQUEST, MediationError::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, + MediationError::CircuitOpen => StatusCode::SERVICE_UNAVAILABLE, }; let body = Json(serde_json::json!({ diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs index 1ac5b86e..e02e718c 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/handler/stateful.rs @@ -19,19 +19,21 @@ use keystore::Secrets; use mongodb::bson::doc; use serde_json::json; use shared::{ + circuit_breaker::CircuitBreaker, midlw::ensure_transport_return_route_is_decorated_all, repository::entity::Connection, retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, - CircuitBreaker::CircuitBreaker, }; use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; use uuid::Uuid; /// Process a DIDComm mediate request pub(crate) async fn process_mediate_request( state: Arc, plain_message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { // This is to Check message type compliance ensure_jwm_type_is_mediation_request(&plain_message)?; @@ -44,140 +46,165 @@ pub(crate) async fn process_mediate_request( let sender_did = plain_message.from.as_ref().unwrap(); - // Retrieve repository to connection entities + // Acquire the CircuitBreaker lock + let mut cb = circuit_breaker.lock().await; - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - // If there is already mediation, send mediate deny - if let Some(_connection) = retry_async( - || { + // Wrap the process logic in the CircuitBreaker call + let result = cb + .call_async(|| { + let state = state.clone(); let sender_did = sender_did.clone(); - let connection_repository = connection_repository.clone(); + let mediator_did = mediator_did.clone(); async move { - connection_repository - .find_one_by(doc! { "client_did": sender_did }) + // Retrieve repository to connection entities + // Retrieve repository to connection entities + + // Retrieve repository to connection entities + + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + // If there is already mediation, send mediate deny + if let Some(_connection) = retry_async( + || { + let sender_did = sender_did.clone(); + let connection_repository = connection_repository.clone(); + + async move { + connection_repository + .find_one_by(doc! { "client_did": sender_did }) + .await + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(1)), + ) + .await + .map_err(|_| MediationError::InternalServerError)? + { + tracing::info!("Sending mediate deny."); + return Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + MEDIATE_DENY_2_0.to_string(), + json!(MediationDeny { + id: format!("urn:uuid:{}", Uuid::new_v4()), + message_type: MEDIATE_DENY_2_0.to_string(), + ..Default::default() + }), + ) + .to(sender_did.clone()) + .from(mediator_did.clone()) + .finalize(), + )); + } else { + // Issue mediate grant response + tracing::info!("Sending mediate grant."); + // Create routing, store it and send mediation grant + let (routing_did, auth_keys, agreem_keys) = + generate_did_peer(state.public_domain.to_string()); + + let AppStateRepository { keystore, .. } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + let diddoc = retry_async( + || { + let did_resolver = state.did_resolver.clone(); + let routing_did = routing_did.clone(); + + async move { did_resolver.resolve(&routing_did).await.map_err(|_| ()) } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) .await - } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(1)), - ) - .await - .map_err(|_| MediationError::InternalServerError)? - { - tracing::info!("Sending mediate deny."); - return Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - MEDIATE_DENY_2_0.to_string(), - json!(MediationDeny { - id: format!("urn:uuid:{}", Uuid::new_v4()), - message_type: MEDIATE_DENY_2_0.to_string(), - ..Default::default() - }), - ) - .to(sender_did.clone()) - .from(mediator_did.clone()) - .finalize(), - )); - } else { - /* Issue mediate grant response */ - tracing::info!("Sending mediate grant."); - // Create routing, store it and send mediation grant - let (routing_did, auth_keys, agreem_keys) = - generate_did_peer(state.public_domain.to_string()); - - let AppStateRepository { keystore, .. } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - let diddoc = retry_async( - || { - let did_resolver = state.did_resolver.clone(); - let routing_did = routing_did.clone(); - - async move { did_resolver.resolve(&routing_did).await.map_err(|_| ()) } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|err| { - tracing::error!("Failed to resolve DID: {:?}", err); - MediationError::InternalServerError - })? - .ok_or(MediationError::InternalServerError)?; - - let agreem_keys_jwk: Jwk = agreem_keys.try_into().unwrap(); - - let agreem_keys_secret = Secrets { - id: None, - kid: diddoc.key_agreement.get(0).unwrap().clone(), - secret_material: agreem_keys_jwk, - }; - - match keystore.store(agreem_keys_secret).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored agreement keys.") - } - Err(error) => tracing::error!("Error storing agreement keys: {:?}", error), - } + .map_err(|err| { + tracing::error!("Failed to resolve DID: {:?}", err); + MediationError::InternalServerError + })? + .ok_or(MediationError::InternalServerError)?; + + let agreem_keys_jwk: Jwk = agreem_keys.try_into().unwrap(); + + let agreem_keys_secret = Secrets { + id: None, + kid: diddoc.key_agreement.get(0).unwrap().clone(), + secret_material: agreem_keys_jwk, + }; + + match keystore.store(agreem_keys_secret).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored agreement keys.") + } + Err(error) => tracing::error!("Error storing agreement keys: {:?}", error), + } - let auth_keys_jwk: Jwk = auth_keys.try_into().unwrap(); + let auth_keys_jwk: Jwk = auth_keys.try_into().unwrap(); + + let auth_keys_secret = Secrets { + id: None, + kid: diddoc.authentication.get(0).unwrap().clone(), + secret_material: auth_keys_jwk, + }; + + match keystore.store(auth_keys_secret).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored authentication keys.") + } + Err(error) => { + tracing::error!("Error storing authentication keys: {:?}", error) + } + } - let auth_keys_secret = Secrets { - id: None, - kid: diddoc.authentication.get(0).unwrap().clone(), - secret_material: auth_keys_jwk, - }; + let mediation_grant = create_mediation_grant(&routing_did); + + let new_connection = Connection { + id: None, + client_did: sender_did.to_string(), + mediator_did: mediator_did.to_string(), + keylist: vec!["".to_string()], + routing_did: routing_did, + }; + + // Use store_one to store the sample connection + match connection_repository.store(new_connection).await { + Ok(_stored_connection) => { + tracing::info!("Successfully stored connection: ") + } + Err(error) => tracing::error!("Error storing connection: {:?}", error), + } - match keystore.store(auth_keys_secret).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored authentication keys.") + Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + mediation_grant.message_type.clone(), + json!(mediation_grant), + ) + .to(sender_did.clone()) + .from(mediator_did.clone()) + .finalize(), + )) + } } - Err(error) => tracing::error!("Error storing authentication keys: {:?}", error), - } - - let mediation_grant = create_mediation_grant(&routing_did); - - let new_connection = Connection { - id: None, - client_did: sender_did.to_string(), - mediator_did: mediator_did.to_string(), - keylist: vec!["".to_string()], - routing_did: routing_did, - }; + }) + .await; - // Use store_one to store the sample connection - match connection_repository.store(new_connection).await { - Ok(_stored_connection) => { - tracing::info!("Successfully stored connection: ") - } - Err(error) => tracing::error!("Error storing connection: {:?}", error), - } - - Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - mediation_grant.message_type.clone(), - json!(mediation_grant), - ) - .to(sender_did.clone()) - .from(mediator_did.clone()) - .finalize(), - )) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), } } @@ -230,6 +257,7 @@ fn generate_did_peer(service_endpoint: String) -> (String, Ed25519KeyPair, X2551 pub(crate) async fn process_plain_keylist_update_message( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { // Extract message sender @@ -242,214 +270,247 @@ pub(crate) async fn process_plain_keylist_update_message( let keylist_update_body: KeylistUpdateBody = serde_json::from_value(message.body) .map_err(|_| MediationError::UnexpectedMessageFormat)?; - // Retrieve repository to connection entities - - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; + let mut cb = circuit_breaker.lock().await; - // Find connection for this keylist update - - let connection = retry_async( - || { - let connection_repository = connection_repository.clone(); + let result = cb + .call_async(|| { + let state = state.clone(); let sender = sender.clone(); + let keylist_update_body = keylist_update_body.clone(); async move { - connection_repository - .find_one_by(doc! { "client_did": &sender }) - .await - .map_err(|_| ()) - } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|err| { - tracing::error!("Failed to find connection after retries: {:?}", err); - MediationError::InternalServerError - })? - .ok_or_else(|| MediationError::UncoordinatedSender)?; - - // Prepare handles to relevant collections - - let mut updated_keylist = connection.keylist.clone(); - let updates = keylist_update_body.updates; - - // Closure to check if a specific key is duplicated across commands - - let key_is_duplicate = |recipient_did| { - updates - .iter() - .filter(|e| &e.recipient_did == recipient_did) - .count() - > 1 - }; - - // Perform updates to persist - - let confirmations: Vec<_> = updates - .iter() - .map(|update| KeylistUpdateConfirmation { - recipient_did: update.recipient_did.clone(), - action: update.action.clone(), - result: { - if let KeylistUpdateAction::Unknown(_) = &update.action { - KeylistUpdateResult::ClientError - } else if key_is_duplicate(&update.recipient_did) { - KeylistUpdateResult::ClientError - } else { - match connection - .keylist + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + // Find connection for this keylist update + + let connection = retry_async( + || { + let connection_repository = connection_repository.clone(); + let sender = sender.clone(); + + async move { + connection_repository + .find_one_by(doc! { "client_did": &sender }) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|err| { + tracing::error!("Failed to find connection after retries: {:?}", err); + MediationError::InternalServerError + })? + .ok_or_else(|| MediationError::UncoordinatedSender)?; + + // Prepare handles to relevant collections + + let mut updated_keylist = connection.keylist.clone(); + let updates = keylist_update_body.updates; + + // Closure to check if a specific key is duplicated across commands + + let key_is_duplicate = |recipient_did| { + updates .iter() - .position(|x| x == &update.recipient_did) - { - Some(index) => match &update.action { - KeylistUpdateAction::Add => KeylistUpdateResult::NoChange, - KeylistUpdateAction::Remove => { - updated_keylist.swap_remove(index); - KeylistUpdateResult::Success + .filter(|e| &e.recipient_did == recipient_did) + .count() + > 1 + }; + + // Process keylist updates + let confirmations: Vec<_> = updates + .iter() + .map(|update| KeylistUpdateConfirmation { + recipient_did: update.recipient_did.clone(), + action: update.action.clone(), + result: { + if let KeylistUpdateAction::Unknown(_) = &update.action { + KeylistUpdateResult::ClientError + } else if key_is_duplicate(&update.recipient_did) { + KeylistUpdateResult::ClientError + } else { + match connection + .keylist + .iter() + .position(|x| x == &update.recipient_did) + { + Some(index) => match &update.action { + KeylistUpdateAction::Add => KeylistUpdateResult::NoChange, + KeylistUpdateAction::Remove => { + updated_keylist.swap_remove(index); + KeylistUpdateResult::Success + } + KeylistUpdateAction::Unknown(_) => unreachable!(), + }, + None => match &update.action { + KeylistUpdateAction::Add => { + updated_keylist.push(update.recipient_did.clone()); + KeylistUpdateResult::Success + } + KeylistUpdateAction::Remove => { + KeylistUpdateResult::NoChange + } + KeylistUpdateAction::Unknown(_) => unreachable!(), + }, + } } - KeylistUpdateAction::Unknown(_) => unreachable!(), }, - None => match &update.action { - KeylistUpdateAction::Add => { - updated_keylist.push(update.recipient_did.clone()); - KeylistUpdateResult::Success + }) + .collect(); + + let confirmations = match connection_repository + .update(Connection { + keylist: updated_keylist, + ..connection + }) + .await + { + Ok(_) => confirmations, + Err(_) => confirmations + .into_iter() + .map(|mut confirmation| { + if confirmation.result != KeylistUpdateResult::ClientError { + confirmation.result = KeylistUpdateResult::ServerError } - KeylistUpdateAction::Remove => KeylistUpdateResult::NoChange, - KeylistUpdateAction::Unknown(_) => unreachable!(), - }, - } - } - }, - }) - .collect(); - - // Persist updated keylist, update confirmations if server error - let confirmations = match connection_repository - .update(Connection { - keylist: updated_keylist, - ..connection + confirmation + }) + .collect(), + }; + + // Build response + + let mediator_did = &state.diddoc.id; + + Ok(Some( + Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + KEYLIST_UPDATE_RESPONSE_2_0.to_string(), + json!(KeylistUpdateResponseBody { + updated: confirmations + }), + ) + .to(sender) + .from(mediator_did.to_owned()) + .finalize(), + )) + } }) - .await - { - Ok(_) => confirmations, - Err(_) => confirmations - .into_iter() - .map(|mut confirmation| { - if confirmation.result != KeylistUpdateResult::ClientError { - confirmation.result = KeylistUpdateResult::ServerError - } - - confirmation - }) - .collect(), - }; - - // Build response - - let mediator_did = &state.diddoc.id; + .await; - Ok(Some( - Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - KEYLIST_UPDATE_RESPONSE_2_0.to_string(), - json!(KeylistUpdateResponseBody { - updated: confirmations - }), - ) - .to(sender) - .from(mediator_did.to_owned()) - .finalize(), - )) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), + } } pub(crate) async fn process_plain_keylist_query_message( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, MediationError> { println!("Processing keylist query..."); let sender = message .from .expect("unpacking middleware failed to prevent anonymous senders"); - let AppStateRepository { - connection_repository, - .. - } = state - .repository - .as_ref() - .ok_or(MediationError::InternalServerError)?; - - let connection = retry_async( - || { - let connection_repository = connection_repository.clone(); + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); let sender = sender.clone(); async move { - connection_repository - .find_one_by(doc! { "client_did": &sender }) - .await - .map_err(|_| ()) - } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|err| { - tracing::error!("Failed to find connection after retries: {:?}", err); - MediationError::InternalServerError - })? - .ok_or_else(|| MediationError::UncoordinatedSender)?; - - println!("keylist: {:?}", connection); - - let keylist_entries = connection - .keylist - .iter() - .map(|key| KeylistEntry { - recipient_did: key.clone(), - }) - .collect::>(); - - let body = KeylistBody { - keys: keylist_entries, - pagination: None, - }; + let AppStateRepository { + connection_repository, + .. + } = state + .repository + .as_ref() + .ok_or(MediationError::InternalServerError)?; + + let connection = retry_async( + || { + let connection_repository = connection_repository.clone(); + let sender = sender.clone(); + + async move { + connection_repository + .find_one_by(doc! { "client_did": &sender }) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|err| { + tracing::error!("Failed to find connection after retries: {:?}", err); + MediationError::InternalServerError + })? + .ok_or_else(|| MediationError::UncoordinatedSender)?; + + println!("keylist: {:?}", connection); + + let keylist_entries = connection + .keylist + .iter() + .map(|key| KeylistEntry { + recipient_did: key.clone(), + }) + .collect::>(); + + let body = KeylistBody { + keys: keylist_entries, + pagination: None, + }; + + let keylist_object = Keylist { + id: format!("urn:uuid:{}", Uuid::new_v4()), + message_type: KEYLIST_2_0.to_string(), + body: body, + additional_properties: None, + }; - let keylist_object = Keylist { - id: format!("urn:uuid:{}", Uuid::new_v4()), - message_type: KEYLIST_2_0.to_string(), - body: body, - additional_properties: None, - }; + let mediator_did = &state.diddoc.id; - let mediator_did = &state.diddoc.id; + let message = Message::build( + format!("urn:uuid:{}", Uuid::new_v4()), + KEYLIST_2_0.to_string(), + json!(keylist_object), + ) + .to(sender.clone()) + .from(mediator_did.clone()) + .finalize(); - let message = Message::build( - format!("urn:uuid:{}", Uuid::new_v4()), - KEYLIST_2_0.to_string(), - json!(keylist_object), - ) - .to(sender.clone()) - .from(mediator_did.clone()) - .finalize(); + println!("message: {:?}", message); - println!("message: {:?}", message); + Ok(Some(message)) + } + }) + .await; - Ok(Some(message)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(MediationError::CircuitOpen), + } } #[cfg(test)] @@ -457,7 +518,8 @@ mod tests { use super::*; use shared::{ - repository::tests::MockConnectionRepository, utils::tests_utils::tests as global, + circuit_breaker, repository::tests::MockConnectionRepository, + utils::tests_utils::tests as global, }; #[allow(clippy::needless_update)] @@ -491,11 +553,17 @@ mod tests { .from(global::_edge_did()) .finalize(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + // Process request - let response = process_plain_keylist_query_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_query_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, KEYLIST_2_0); assert_eq!(response.from.unwrap(), global::_mediator_did(&state)); @@ -515,10 +583,16 @@ mod tests { .from("did:example:uncoordinated_sender".to_string()) .finalize(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + // Process request - let err = process_plain_keylist_query_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let err = process_plain_keylist_query_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error for uncoordinated sender assert_eq!(err, MediationError::UncoordinatedSender,); } @@ -551,10 +625,16 @@ mod tests { // Process request - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); let response = response; // Assert metadata @@ -650,10 +730,16 @@ mod tests { // Process request - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates assert_eq!( @@ -713,11 +799,16 @@ mod tests { .finalize(); // Process request + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates assert_eq!( @@ -773,11 +864,16 @@ mod tests { .finalize(); // Process request + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); - let response = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap() - .expect("Response should not be None"); + let response = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); // Assert updates @@ -811,9 +907,15 @@ mod tests { .finalize(); // Process request - let err = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let err = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error assert_eq!(err, MediationError::UnexpectedMessageFormat,); @@ -857,9 +959,15 @@ mod tests { .finalize(); // Process request - let err = process_plain_keylist_update_message(Arc::clone(&state), message) - .await - .unwrap_err(); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let err = process_plain_keylist_update_message( + Arc::clone(&state), + message, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap_err(); // Assert issued error assert_eq!(err, MediationError::UncoordinatedSender,); diff --git a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs index 3aab4d69..978da015 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/mediator-coordination/src/plugin.rs @@ -3,8 +3,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; pub struct MediatorCoordinationProtocol; @@ -19,7 +20,12 @@ impl MessageHandler for MediateRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_mediate_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_mediate_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -32,7 +38,12 @@ impl MessageHandler for KeylistUpdateHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_plain_keylist_update_message(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_plain_keylist_update_message(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -45,7 +56,12 @@ impl MessageHandler for KeylistQueryHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::stateful::process_plain_keylist_query_message(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::stateful::process_plain_keylist_query_message(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml b/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml index 97694183..5b8fd91c 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/Cargo.toml @@ -18,6 +18,7 @@ thiserror.workspace = true async-trait.workspace = true uuid = { workspace = true, features = ["v4"] } axum = { workspace = true, features = ["macros"] } +tokio = "1.27.0" [dev-dependencies] hyper = "0.14.27" diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs index e5f281fd..36778e81 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/handler.rs @@ -10,237 +10,320 @@ use didcomm::{Attachment, Message, MessageBuilder}; use mongodb::bson::{doc, oid::ObjectId}; use serde_json::Value; use shared::{ + circuit_breaker::CircuitBreaker, midlw::ensure_transport_return_route_is_decorated_all, repository::entity::{Connection, RoutedMessage}, retry::{retry_async, RetryOptions}, state::{AppState, AppStateRepository}, }; use std::{str::FromStr, sync::Arc, time::Duration}; +use tokio::sync::Mutex; use uuid::Uuid; // Process pickup status request pub(crate) async fn handle_status_request( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; let mediator_did = &state.diddoc.id; - let recipient_did = message - .body - .get("recipient_did") - .and_then(|val| val.as_str()); let sender_did = sender_did(&message)?; - let repository = repository(state.clone())?; - - let connection = retry_async( - || { - let repository = repository.clone(); - async move { client_connection(&repository, sender_did).await } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|_| PickupError::InternalError("Failed to retrieve client connection".to_owned()))?; - - let message_count = count_messages(repository, recipient_did, connection).await?; - - let id = Uuid::new_v4().urn().to_string(); - let response_builder: MessageBuilder = StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - recipient_did, - message_count, - live_delivery: Some(false), - ..Default::default() - }, - } - .into(); + let mut cb = circuit_breaker.lock().await; - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); + let result = cb + .call_async(|| { + let state = state.clone(); + let message = message.clone(); + let circuit_breaker = circuit_breaker.clone(); + async move { + let recipient_did = message + .body + .get("recipient_did") + .and_then(|val| val.as_str()); + + let repository = repository(state.clone())?; + + let connection = retry_async( + || { + let repository = repository.clone(); + async move { client_connection(&repository, sender_did).await } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError("Failed to retrieve client connection".to_owned()) + })?; + + // Pass `recipient_did` to count_messages, allowing it to handle `None` + let message_count = count_messages( + repository, + recipient_did, + connection, + circuit_breaker.clone(), + ) + .await?; + + let id = Uuid::new_v4().urn().to_string(); + let response_builder: MessageBuilder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + recipient_did: recipient_did.to_owned(), + message_count, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) + } + }) + .await; - Ok(Some(response)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } // Process pickup delivery request pub(crate) async fn handle_delivery_request( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; let mediator_did = &state.diddoc.id; - let recipient_did = message - .body - .get("recipient_did") - .and_then(|val| val.as_str()); + let sender_did = sender_did(&message)?; let message_body = message.body.clone(); - let limit = retry_async( - || { + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { + let state = state.clone(); let message_body = message_body.clone(); + let circuit_breaker = circuit_breaker.clone(); async move { - message_body - .get("limit") - .and_then(Value::as_u64) - .ok_or_else(|| { - PickupError::MalformedRequest("Invalid \"limit\" specifier".to_owned()) - }) + let recipient_did = message_body.get("recipient_did").and_then(Value::as_str); + + let limit = retry_async( + || { + let message_body = message_body.clone(); + async move { + message_body + .get("limit") + .and_then(Value::as_u64) + .ok_or_else(|| { + PickupError::MalformedRequest( + "Invalid \"limit\" specifier".to_owned(), + ) + }) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await?; + + let repository = repository(state.clone())?; + let connection = retry_async( + || { + let repository = repository.clone(); + async move { client_connection(&repository, sender_did).await } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError("Failed to retrieve client connection".to_owned()) + })?; + + let messages = messages( + repository, + recipient_did, + connection, + limit as usize, + circuit_breaker.clone(), + ) + .await?; + + let response_builder: MessageBuilder; + let id = Uuid::new_v4().urn().to_string(); + + if messages.is_empty() { + response_builder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + recipient_did, + message_count: 0, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + } else { + let mut attachments: Vec = Vec::with_capacity(messages.len()); + + for message in messages { + let attached = Attachment::json(message.message) + .id(message.id.map(|id| id.to_string()).ok_or_else(|| { + PickupError::InternalError( + "Failed to load requested messages. Please try again later." + .to_owned(), + ) + })?) + .finalize(); + + attachments.push(attached); + } + + response_builder = DeliveryResponse { + id: id.as_str(), + thid: id.as_str(), + type_: MESSAGE_DELIVERY_3_0, + body: BodyDeliveryResponse { recipient_did }, + attachments, + } + .into(); + } + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await?; - - let repository = repository(state.clone())?; - let connection = retry_async( - || { - let repository = repository.clone(); - async move { client_connection(&repository, sender_did).await } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|_| PickupError::InternalError("Failed to retrieve client connection".to_owned()))?; - - let messages = messages(repository, recipient_did, connection, limit as usize).await?; - - let response_builder: MessageBuilder; - let id = Uuid::new_v4().urn().to_string(); - - if messages.is_empty() { - response_builder = StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - recipient_did, - message_count: 0, - live_delivery: Some(false), - ..Default::default() - }, - } - .into(); - } else { - let mut attachments: Vec = Vec::with_capacity(messages.len()); - - for message in messages { - let attached = Attachment::json(message.message) - .id(message.id.map(|id| id.to_string()).ok_or_else(|| { - PickupError::InternalError( - "Failed to load requested messages. Please try again later.".to_owned(), - ) - })?) - .finalize(); - - attachments.push(attached); - } + }) + .await; - response_builder = DeliveryResponse { - id: id.as_str(), - thid: id.as_str(), - type_: MESSAGE_DELIVERY_3_0, - body: BodyDeliveryResponse { recipient_did }, - attachments, - } - .into(); + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), } - - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); - - Ok(Some(response)) } // Process pickup messages acknowledgement pub(crate) async fn handle_message_acknowledgement( state: Arc, message: Message, + circuit_breaker: Arc>, ) -> Result, PickupError> { // Validate the return_route header ensure_transport_return_route_is_decorated_all(&message) .map_err(|_| PickupError::MalformedRequest("Missing return_route header".to_owned()))?; - let mediator_did = &state.diddoc.id; - let repository = repository(state.clone())?; - let sender_did = sender_did(&message)?; - let connection = client_connection(&repository, sender_did).await?; - - // Get the message id list - let message_id_list = message - .body - .get("message_id_list") - .and_then(|v| v.as_array()) - .map(|a| a.iter().filter_map(|v| v.as_str()).collect::>()) - .ok_or_else(|| { - PickupError::MalformedRequest("Invalid \"message_id_list\" specifier".to_owned()) - })?; - - for id in message_id_list { - let msg_id = ObjectId::from_str(id) - .map_err(|_| PickupError::MalformedRequest(format!("Invalid message id: {id}")))?; - - retry_async( - || { - let message_repository = repository.message_repository.clone(); - let msg_id = msg_id.clone(); - - async move { message_repository.delete_one(msg_id).await.map_err(|_| ()) } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(2)), - ) - .await - .map_err(|_| { - PickupError::InternalError( - "Failed to process the request. Please try again later.".to_owned(), - ) - })?; - } - - let message_count = count_messages(repository, None, connection).await?; - - let id = Uuid::new_v4().urn().to_string(); - let response_builder: MessageBuilder = StatusResponse { - id: id.as_str(), - type_: STATUS_RESPONSE_3_0, - body: BodyStatusResponse { - message_count, - live_delivery: Some(false), - ..Default::default() - }, - } - .into(); + // Acquire the CircuitBreaker lock + let mut cb = circuit_breaker.lock().await; - let response = response_builder - .to(sender_did.to_owned()) - .from(mediator_did.to_owned()) - .finalize(); + // Wrap the message acknowledgement logic in the CircuitBreaker call + let result = cb + .call_async(|| { + let state = state.clone(); + let message = message.clone(); + let circuit_breaker = circuit_breaker.clone(); + async move { + let mediator_did = &state.diddoc.id; + let repository = repository(state.clone())?; + let sender_did = sender_did(&message)?; + let connection = client_connection(&repository, sender_did).await?; + + // Get the message ID list + let message_id_list = message + .body + .get("message_id_list") + .and_then(|v| v.as_array()) + .map(|a| a.iter().filter_map(|v| v.as_str()).collect::>()) + .ok_or_else(|| { + PickupError::MalformedRequest( + "Invalid \"message_id_list\" specifier".to_owned(), + ) + })?; + + for id in message_id_list { + let msg_id = ObjectId::from_str(id).map_err(|_| { + PickupError::MalformedRequest(format!("Invalid message id: {id}")) + })?; + + retry_async( + || { + let message_repository = repository.message_repository.clone(); + let msg_id = msg_id.clone(); + + async move { message_repository.delete_one(msg_id).await.map_err(|_| ()) } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError( + "Failed to process the request. Please try again later.".to_owned(), + ) + })?; + } + + let message_count = + count_messages(repository, None, connection, circuit_breaker).await?; + + let id = Uuid::new_v4().urn().to_string(); + let response_builder: MessageBuilder = StatusResponse { + id: id.as_str(), + type_: STATUS_RESPONSE_3_0, + body: BodyStatusResponse { + message_count, + live_delivery: Some(false), + ..Default::default() + }, + } + .into(); + + let response = response_builder + .to(sender_did.to_owned()) + .from(mediator_did.to_owned()) + .finalize(); + + Ok(Some(response)) + } + }) + .await; - Ok(Some(response)) + match result { + Some(Ok(response)) => Ok(response), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } // Process live delivery change request @@ -288,32 +371,50 @@ async fn count_messages( repository: AppStateRepository, recipient_did: Option<&str>, connection: Connection, + circuit_breaker: Arc>, ) -> Result { let recipients = recipients(recipient_did, &connection); - retry_async( - || { + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { let message_repository = repository.message_repository.clone(); let recipients = recipients.clone(); async move { - message_repository - .count_by(doc! { "recipient_did": { "$in": recipients } }) - .await - .map_err(|_| ()) + retry_async( + || { + let message_repository = message_repository.clone(); + let recipients = recipients.clone(); + + async move { + message_repository + .count_by(doc! { "recipient_did": { "$in": recipients } }) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(2)), + ) + .await + .map_err(|_| { + PickupError::InternalError( + "Failed to process the request. Please try again later.".to_owned(), + ) + }) } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(1)), - ) - .await - .map_err(|_| { - PickupError::InternalError( - "Failed to process the request. Please try again later.".to_owned(), - ) - }) + }) + .await; + + match result { + Some(Ok(count)) => Ok(count), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } async fn messages( @@ -321,35 +422,53 @@ async fn messages( recipient_did: Option<&str>, connection: Connection, limit: usize, + circuit_breaker: Arc>, ) -> Result, PickupError> { let recipients = recipients(recipient_did, &connection); - retry_async( - || { + let mut cb = circuit_breaker.lock().await; + + let result = cb + .call_async(|| { let message_repository = repository.message_repository.clone(); let recipients = recipients.clone(); async move { - message_repository - .find_all_by( - doc! { "recipient_did": { "$in": recipients } }, - Some(limit as i64), + retry_async( + || { + let message_repository = message_repository.clone(); + let recipients = recipients.clone(); + + async move { + message_repository + .find_all_by( + doc! { "recipient_did": { "$in": recipients } }, + Some(limit as i64), + ) + .await + .map_err(|_| ()) + } + }, + RetryOptions::new() + .retries(5) + .exponential_backoff(Duration::from_millis(100)) + .max_delay(Duration::from_secs(1)), + ) + .await + .map_err(|_| { + PickupError::InternalError( + "Failed to retrieve messages. Please try again later.".to_owned(), ) - .await - .map_err(|_| ()) + }) } - }, - RetryOptions::new() - .retries(5) - .exponential_backoff(Duration::from_millis(100)) - .max_delay(Duration::from_secs(1)), - ) - .await - .map_err(|_| { - PickupError::InternalError( - "Failed to process the request. Please try again later.".to_owned(), - ) - }) + }) + .await; + + match result { + Some(Ok(messages)) => Ok(messages), + Some(Err(err)) => Err(err), + None => Err(PickupError::CircuitOpen), + } } #[inline] @@ -407,6 +526,7 @@ mod tests { }; use serde_json::json; use shared::{ + circuit_breaker, repository::tests::{MockConnectionRepository, MockMessagesRepository}, utils::tests_utils::tests as global, }; @@ -471,10 +591,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!(response.from.unwrap(), global::_mediator_did(&state)); @@ -500,10 +625,16 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!( response.body, @@ -527,10 +658,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_status_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_status_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!( response.body, @@ -553,11 +689,13 @@ mod tests { .from("did:key:invalid".to_owned()) .finalize(); - let error = handle_status_request(state, invalid_request) + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let error = handle_status_request(state, invalid_request, Arc::new(circuit_breaker.into())) .await .unwrap_err(); assert_eq!(error.to_string(), "Failed to retrieve client connection"); + // assert_eq!(error, PickupError::MissingClientConnection); } #[tokio::test] @@ -574,10 +712,15 @@ mod tests { .from(global::_edge_did()) .finalize(); - let response = handle_delivery_request(Arc::clone(&state), request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request( + Arc::clone(&state), + request, + Arc::new(circuit_breaker.into()), + ) + .await + .unwrap() + .expect("Response should not be None"); let expected_attachments = vec![ Attachment::json(json!("test1")) @@ -619,7 +762,8 @@ mod tests { // When the specified recipient did is not in the keylist, // it should return a status response with a message count of 0 - let response = handle_delivery_request(state, request) + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -648,7 +792,9 @@ mod tests { // When the limit is set to 0, it should return all the messages in the queue // and since the recipient did is not specified, it should return the messages // for all the dids in the keylist for that sender connection - let response = handle_delivery_request(state, request) + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -688,7 +834,9 @@ mod tests { // Since the recipient did is not specified, it should return the messages // for all the dids in the keylist for that sender connection (2 in this case) // The limit is set to 1 so it should return the first message in the queue - let response = handle_delivery_request(state, request) + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = handle_delivery_request(state, request, Arc::new(circuit_breaker.into())) .await .unwrap() .expect("Response should not be None"); @@ -719,10 +867,13 @@ mod tests { .finalize(); // Should return 2 since these ids are not associated with any message - let response = handle_message_acknowledgement(state, request) - .await - .unwrap() - .expect("Response should not be None"); + + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = + handle_message_acknowledgement(state, request, Arc::new(circuit_breaker.into())) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!( @@ -747,10 +898,12 @@ mod tests { // Should return 1 since one id in the list is associated // to the first message in the queue and then will be deleted - let response = handle_message_acknowledgement(state, request) - .await - .unwrap() - .expect("Response should not be None"); + let circuit_breaker = circuit_breaker::CircuitBreaker::new(3, Duration::from_secs(3)); + let response = + handle_message_acknowledgement(state, request, Arc::new(circuit_breaker.into())) + .await + .unwrap() + .expect("Response should not be None"); assert_eq!(response.type_, STATUS_RESPONSE_3_0); assert_eq!( diff --git a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs index c40df56e..1d312070 100644 --- a/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs +++ b/crates/web-plugins/didcomm-messaging/protocols/pickup/src/plugin.rs @@ -5,8 +5,9 @@ use async_trait::async_trait; use axum::response::{IntoResponse, Response}; use didcomm::Message; use message_api::{MessageHandler, MessagePlugin, MessageRouter}; -use shared::state::AppState; -use std::sync::Arc; +use shared::{circuit_breaker::CircuitBreaker, state::AppState}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::Mutex; pub struct PickupProtocol; @@ -22,7 +23,12 @@ impl MessageHandler for StatusRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_status_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_status_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -35,7 +41,12 @@ impl MessageHandler for DeliveryRequestHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_delivery_request(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_delivery_request(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } @@ -48,7 +59,12 @@ impl MessageHandler for MessageReceivedHandler { state: Arc, msg: Message, ) -> Result, Response> { - crate::handler::handle_message_acknowledgement(state, msg) + let circuit_breaker = Arc::new(Mutex::new(CircuitBreaker::new( + 2, + Duration::from_millis(5000), + ))); + + crate::handler::handle_message_acknowledgement(state, msg, circuit_breaker) .await .map_err(|e| e.into_response()) } diff --git a/crates/web-plugins/didcomm-messaging/shared/src/CircuitBreaker.rs b/crates/web-plugins/didcomm-messaging/shared/src/CircuitBreaker.rs deleted file mode 100644 index 836a11d3..00000000 --- a/crates/web-plugins/didcomm-messaging/shared/src/CircuitBreaker.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Mutex, -}; -use std::time::{Duration, Instant}; - -#[derive(Debug)] -pub struct CircuitBreaker { - state: AtomicBool, // true = Open, false = Closed - failure_count: AtomicUsize, - last_failure_time: Mutex>, - threshold: usize, - reset_timeout: Duration, -} - -impl CircuitBreaker { - /// Creating a new CircuitBreaker with the given failure threshold and reset timeout. - pub fn new(threshold: usize, reset_timeout: Duration) -> Self { - Self { - state: AtomicBool::new(false), - failure_count: AtomicUsize::new(0), - last_failure_time: Mutex::new(None), - threshold, - reset_timeout, - } - } - - pub fn is_open(&self) -> bool { - if self.state.load(Ordering::Relaxed) { - let mut last_failure_time = self.last_failure_time.lock().unwrap(); - if let Some(last_time) = *last_failure_time { - if last_time.elapsed() > self.reset_timeout { - self.state.store(false, Ordering::Relaxed); - self.failure_count.store(0, Ordering::Relaxed); - *last_failure_time = None; - return false; - } - } - true - } else { - false - } - } - - pub fn record_failure(&self) { - let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1; - if failures >= self.threshold { - self.state.store(true, Ordering::Relaxed); - let mut last_failure_time = self.last_failure_time.lock().unwrap(); - *last_failure_time = Some(Instant::now()); - } - } - - pub fn record_success(&self) { - self.failure_count.store(0, Ordering::Relaxed); - self.state.store(false, Ordering::Relaxed); - } - - pub fn call(&self, f: F) -> Result>, String> - where - F: FnOnce() -> Result, - { - if self.is_open() { - return Ok(None); - } - - let result = f(); - match result { - Ok(_) => self.record_success(), - Err(_) => self.record_failure(), - } - - Ok(Some(result)) - } -} diff --git a/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs b/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs new file mode 100644 index 00000000..2238cd16 --- /dev/null +++ b/crates/web-plugins/didcomm-messaging/shared/src/circuit_breaker.rs @@ -0,0 +1,90 @@ +use std::time::{Duration, Instant}; + +#[derive(Debug)] +enum State { + // The circuit breaker is closed and allowing requests + // to pass through + Closed, + // The circuit breaker is open and blocking requests + Open, + // The circuit breaker is half-open and allowing a limited + // number of requests to pass through + HalfOpen, +} + +pub struct CircuitBreaker { + state: State, + // The duration to wait before transitioning from the + // open state to the half-open state + trip_timeout: Duration, + // The maximum number of requests allowed through in + // the closed state + max_failures: usize, + // The number of consecutive failures in the closed + // state + consecutive_failures: usize, + // The time when the circuit breaker transitioned to the + // open state + opened_at: Option, +} + +impl CircuitBreaker { + pub fn new(max_failures: usize, trip_timeout: Duration) -> CircuitBreaker { + CircuitBreaker { + state: State::Closed, + max_failures, + trip_timeout, + consecutive_failures: 0, + opened_at: None, + } + } + + pub async fn call_async(&mut self, f: F) -> Option> + where + F: FnOnce() -> Fut, + Fut: std::future::Future>, + { + match self.state { + State::Closed => { + if self.consecutive_failures < self.max_failures { + let result = f().await; + if result.is_err() { + self.record_failure(); + } + Some(result) + } else { + self.opened_at = Some(Instant::now()); + self.state = State::Open; + self.consecutive_failures = 0; + None + } + } + State::Open => { + if let Some(opened_at) = self.opened_at { + if Instant::now().duration_since(opened_at) >= self.trip_timeout { + self.state = State::HalfOpen; + self.opened_at = None; + } + } + None + } + State::HalfOpen => { + let result = f().await; + if result.is_err() { + self.state = State::Open; + } else { + self.state = State::Closed; + } + Some(result) + } + } + } + + fn record_failure(&mut self) { + match self.state { + State::Closed => self.consecutive_failures += 1, + State::Open => (), + State::HalfOpen => self.consecutive_failures += 1, + } + } +} diff --git a/crates/web-plugins/didcomm-messaging/shared/src/lib.rs b/crates/web-plugins/didcomm-messaging/shared/src/lib.rs index a777ec3a..77ffa0c8 100644 --- a/crates/web-plugins/didcomm-messaging/shared/src/lib.rs +++ b/crates/web-plugins/didcomm-messaging/shared/src/lib.rs @@ -1,4 +1,4 @@ -pub mod CircuitBreaker; +pub mod circuit_breaker; pub mod errors; pub mod midlw; pub mod repository; diff --git a/crates/web-plugins/didcomm-messaging/shared/src/state.rs b/crates/web-plugins/didcomm-messaging/shared/src/state.rs index 38e221fa..34a3f389 100644 --- a/crates/web-plugins/didcomm-messaging/shared/src/state.rs +++ b/crates/web-plugins/didcomm-messaging/shared/src/state.rs @@ -4,9 +4,7 @@ use keystore::Secrets; use std::{sync::Arc, time::Duration}; use crate::{ - repository::entity::{Connection, RoutedMessage}, - utils::resolvers::{LocalDIDResolver, LocalSecretsResolver}, - CircuitBreaker::CircuitBreaker, + circuit_breaker::CircuitBreaker, repository::entity::{Connection, RoutedMessage}, utils::resolvers::{LocalDIDResolver, LocalSecretsResolver} }; #[derive(Clone)]