Skip to content

Commit

Permalink
req-resp: Refactor to move functionality to dedicated methods (#244)
Browse files Browse the repository at this point in the history
This PR refactors the `RequstResponse::run()` method to move away from
writing code in `loop { tokio::select! }`.
Instead, the functionality is moved to dedicated functions.

While at it, have remove the `async` from functions that were plain sync
and added comments to make the polling a bit clearer.

cc @paritytech/networking

---------

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>
  • Loading branch information
lexnv authored Sep 18, 2024
1 parent a68b713 commit c19c144
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 109 deletions.
1 change: 1 addition & 0 deletions src/protocol/request_response/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ impl From<SubstreamError> for RejectReason {
}

/// Request-response events.
#[derive(Debug)]
pub(super) enum InnerRequestResponseEvent {
/// Request received from remote
RequestReceived {
Expand Down
271 changes: 167 additions & 104 deletions src/protocol/request_response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ pub(crate) struct RequestResponseProtocol {
/// notifies the future that the request should be rejected by closing the substream.
pending_outbound_responses: FuturesUnordered<BoxFuture<'static, ()>>,

/// Pending inbound responses.
pending_inbound: FuturesUnordered<BoxFuture<'static, PendingRequest>>,

/// Pending outbound cancellation handles.
pending_outbound_cancels: HashMap<RequestId, oneshot::Sender<()>>,

/// Pending inbound responses.
pending_inbound: FuturesUnordered<BoxFuture<'static, PendingRequest>>,

/// Pending inbound requests.
pending_inbound_requests: SubstreamSet<(PeerId, RequestId), Substream>,

Expand Down Expand Up @@ -227,6 +227,12 @@ impl RequestResponseProtocol {

match self.pending_dials.remove(&peer) {
None => {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
"peer connected without pending dial",
);
entry.insert(PeerContext::new());
}
Some(context) => match self.service.open_substream(peer) {
Expand Down Expand Up @@ -706,14 +712,14 @@ impl RequestResponseProtocol {
}

/// Send request to remote peer.
async fn on_send_request(
fn on_send_request(
&mut self,
peer: PeerId,
request_id: RequestId,
request: Vec<u8>,
dial_options: DialOptions,
fallback: Option<(ProtocolName, Vec<u8>)>,
) -> crate::Result<()> {
) -> Result<(), RequestResponseError> {
tracing::trace!(
target: LOG_TARGET,
?peer,
Expand All @@ -735,13 +741,7 @@ impl RequestResponseProtocol {
"peer not connected and should not dial",
);

return self
.report_request_failure(
peer,
request_id,
RequestResponseError::NotConnected,
)
.await;
return Err(RequestResponseError::NotConnected);
}
DialOptions::Dial => match self.service.dial(&peer) {
Ok(_) => {
Expand All @@ -768,15 +768,9 @@ impl RequestResponseProtocol {
"failed to dial peer"
);

return self
.report_request_failure(
peer,
request_id,
RequestResponseError::Rejected(RejectReason::DialFailed(Some(
error,
))),
)
.await;
return Err(RequestResponseError::Rejected(RejectReason::DialFailed(
Some(error),
)));
}
},
}
Expand Down Expand Up @@ -806,12 +800,7 @@ impl RequestResponseProtocol {
"failed to open substream",
);

self.report_request_failure(
peer,
request_id,
RequestResponseError::Rejected(error.into()),
)
.await
return Err(RequestResponseError::Rejected(error.into()));
}
}
}
Expand Down Expand Up @@ -871,7 +860,7 @@ impl RequestResponseProtocol {
}

/// Cancel outbound request.
async fn on_cancel_request(&mut self, request_id: RequestId) -> crate::Result<()> {
fn on_cancel_request(&mut self, request_id: RequestId) -> crate::Result<()> {
tracing::trace!(target: LOG_TARGET, protocol = %self.protocol, ?request_id, "cancel outbound request");

match self.pending_outbound_cancels.remove(&request_id) {
Expand All @@ -889,6 +878,142 @@ impl RequestResponseProtocol {
}
}

/// Handles the service event.
async fn handle_service_event(&mut self, event: TransportEvent) {
match event {
TransportEvent::ConnectionEstablished { peer, .. } => {
if let Err(error) = self.on_connection_established(peer).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?error,
"failed to handle connection established",
);
}
}

TransportEvent::ConnectionClosed { peer } => {
self.on_connection_closed(peer).await;
}

TransportEvent::SubstreamOpened {
peer,
substream,
direction,
fallback,
..
} => match direction {
Direction::Inbound => {
if let Err(error) = self.on_inbound_substream(peer, fallback, substream).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?error,
"failed to handle inbound substream",
);
}
}
Direction::Outbound(substream_id) => {
let _ =
self.on_outbound_substream(peer, substream_id, substream, fallback).await;
}
},

TransportEvent::SubstreamOpenFailure { substream, error } => {
if let Err(error) = self.on_substream_open_failure(substream, error).await {
tracing::warn!(
target: LOG_TARGET,
protocol = %self.protocol,
?error,
"failed to handle substream open failure",
);
}
}

TransportEvent::DialFailure { peer, .. } => self.on_dial_failure(peer).await,
}
}

/// Handles the user command.
async fn handle_user_command(&mut self, command: RequestResponseCommand) {
match command {
RequestResponseCommand::SendRequest {
peer,
request_id,
request,
dial_options,
} => {
if let Err(error) =
self.on_send_request(peer, request_id, request, dial_options, None)
{
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to send request",
);

if let Err(error) = self.report_request_failure(peer, request_id, error).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to report request failure",
);
}
}
}
RequestResponseCommand::SendRequestWithFallback {
peer,
request_id,
request,
fallback,
dial_options,
} => {
if let Err(error) =
self.on_send_request(peer, request_id, request, dial_options, Some(fallback))
{
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to send request",
);

if let Err(error) = self.report_request_failure(peer, request_id, error).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to report request failure",
);
}
}
}
RequestResponseCommand::CancelRequest { request_id } => {
if let Err(error) = self.on_cancel_request(request_id) {
tracing::debug!(
target: LOG_TARGET,
protocol = %self.protocol,
?request_id,
?error,
"failed to cancel reqeuest",
);
}
}
}
}

/// Start [`RequestResponseProtocol`] event loop.
pub async fn run(mut self) {
tracing::debug!(target: LOG_TARGET, "starting request-response event loop");
Expand All @@ -899,48 +1024,16 @@ impl RequestResponseProtocol {
// responses to network behaviour so ensure that the commands operate on the most up to date information.
biased;

// Connection and substream events from the transport service.
event = self.service.next() => match event {
Some(TransportEvent::ConnectionEstablished { peer, .. }) => {
let _ = self.on_connection_established(peer).await;
}
Some(TransportEvent::ConnectionClosed { peer }) => {
self.on_connection_closed(peer).await;
}
Some(TransportEvent::SubstreamOpened {
peer,
substream,
direction,
fallback,
..
}) => match direction {
Direction::Inbound => {
if let Err(error) = self.on_inbound_substream(peer, fallback, substream).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?error,
"failed to handle inbound substream",
);
}
}
Direction::Outbound(substream_id) => {
let _ = self.on_outbound_substream(peer, substream_id, substream, fallback).await;
}
},
Some(TransportEvent::SubstreamOpenFailure { substream, error }) => {
if let Err(error) = self.on_substream_open_failure(substream, error).await {
tracing::warn!(
target: LOG_TARGET,
protocol = %self.protocol,
?error,
"failed to handle substream open failure",
);
}
Some(event) => self.handle_service_event(event).await,
None => {
tracing::debug!(target: LOG_TARGET, protocol = %self.protocol, "service has exited, exiting");
return
}
Some(TransportEvent::DialFailure { peer, .. }) => self.on_dial_failure(peer).await,
None => return,
},

// These are outbound requests waiting for the substream to produce a response.
event = self.pending_inbound.select_next_some(), if !self.pending_inbound.is_empty() => {
let (peer, request_id, fallback, event) = event;

Expand All @@ -957,7 +1050,11 @@ impl RequestResponseProtocol {

self.pending_outbound_cancels.remove(&request_id);
}

// These are inbound requests waiting for the user to respond, then for the substream to send the response.
_ = self.pending_outbound_responses.next(), if !self.pending_outbound_responses.is_empty() => {}

// Inbound requests that are moved to `pending_outbound_responses`.
event = self.pending_inbound_requests.next() => match event {
Some(((peer, request_id), message)) => {
if let Err(error) = self.on_inbound_request(peer, request_id, message).await {
Expand All @@ -973,48 +1070,14 @@ impl RequestResponseProtocol {
}
None => return,
},

// User commands.
command = self.command_rx.recv() => match command {
Some(command) => self.handle_user_command(command).await,
None => {
tracing::debug!(target: LOG_TARGET, protocol = %self.protocol, "user protocol has exited, exiting");
return
}
Some(command) => match command {
RequestResponseCommand::SendRequest { peer, request_id, request, dial_options } => {
if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, None).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to send request",
);
}
}
RequestResponseCommand::CancelRequest { request_id } => {
if let Err(error) = self.on_cancel_request(request_id).await {
tracing::debug!(
target: LOG_TARGET,
protocol = %self.protocol,
?request_id,
?error,
"failed to cancel reqeuest",
);
}
}
RequestResponseCommand::SendRequestWithFallback { peer, request_id, request, fallback, dial_options } => {
if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, Some(fallback)).await {
tracing::debug!(
target: LOG_TARGET,
?peer,
protocol = %self.protocol,
?request_id,
?error,
"failed to send request",
);
}
}
}
},
}
}
Expand Down
Loading

0 comments on commit c19c144

Please sign in to comment.