Skip to content

Commit

Permalink
Merge pull request #377 from nervosnetwork/fix-session-protocol-order…
Browse files Browse the repository at this point in the history
…-on-open

fix: session open protocol open order
  • Loading branch information
driftluo authored Nov 20, 2024
2 parents 2aa6dd4 + 0394ae6 commit 8cf4f13
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 72 deletions.
188 changes: 117 additions & 71 deletions tentacle/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,12 @@ where

let _ignore = inner
.handle_sender
.send(ServiceEventAndError::Event(ServiceEvent::ListenStarted {
address: listen_address.clone(),
}))
.send(
ServiceEvent::ListenStarted {
address: listen_address.clone(),
}
.into(),
)
.await;
#[cfg(feature = "upnp")]
if let Some(client) = inner.igd_client.as_mut() {
Expand Down Expand Up @@ -270,8 +273,16 @@ where

while let Some(s) = self.recv.next().await {
match s {
ServiceEventAndError::Event(e) => {
self.handle.handle_event(&mut self.service_context, e).await
ServiceEventAndError::Event {
event,
wait_response,
} => {
self.handle
.handle_event(&mut self.service_context, event)
.await;
if let Some(tx) = wait_response {
let _ignore = tx.send(());
}
}
ServiceEventAndError::Error(e) => {
self.handle.handle_error(&mut self.service_context, e).await
Expand Down Expand Up @@ -705,6 +716,18 @@ where
self.future_task_sender.clone(),
);

// session open event must be notified first, and then the protocol is enabled
let (tx, rx) = futures::channel::oneshot::channel();
let _ignore = self
.handle_sender
.send(
Into::<ServiceEventAndError>::into(ServiceEvent::SessionOpen { session_context })
.wait_response(tx),
)
.await;
// Don't care about it's drop or response
let _ignore = rx.await;

if ty.is_outbound() {
match target {
TargetProtocol::All => {
Expand All @@ -726,13 +749,6 @@ where
}

crate::runtime::spawn(session.for_each(|_| future::ready(())));

let _ignore = self
.handle_sender
.send(ServiceEventAndError::Event(ServiceEvent::SessionOpen {
session_context,
}))
.await;
}

/// Close the specified session, clean up the handle
Expand All @@ -757,9 +773,12 @@ where
// Service handle processing flow
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Event(ServiceEvent::SessionClose {
session_context: session_control.inner,
}))
.send(
ServiceEvent::SessionClose {
session_context: session_control.inner,
}
.into(),
)
.await;
}
}
Expand Down Expand Up @@ -857,12 +876,13 @@ where
{
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(
.send(
ServiceError::ProtocolHandleError {
proto_id: *proto_id,
error: ProtocolHandleErrorKind::AbnormallyClosed(None),
},
))
}
.into(),
)
.await;
error = true;
}
Expand All @@ -879,12 +899,13 @@ where
error = true;
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(
.send(
ServiceError::ProtocolHandleError {
proto_id: *proto_id,
error: ProtocolHandleErrorKind::AbnormallyClosed(Some(*session_id)),
},
))
}
.into(),
)
.await;
}
}
Expand Down Expand Up @@ -921,10 +942,13 @@ where
self.dial_protocols.remove(&address);
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::DialerError {
address,
error: DialerErrorKind::HandshakeError(error),
}))
.send(
ServiceError::DialerError {
address,
error: DialerErrorKind::HandshakeError(error),
}
.into(),
)
.await;
}
}
Expand All @@ -935,12 +959,13 @@ where
if let Some(session_control) = self.sessions.get(&id) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(
.send(
ServiceError::ProtocolSelectError {
proto_name,
session_context: Arc::clone(&session_control.inner),
},
))
}
.into(),
)
.await;
}
}
Expand All @@ -951,32 +976,41 @@ where
} => {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::ProtocolError {
id,
proto_id,
error,
}))
.send(
ServiceError::ProtocolError {
id,
proto_id,
error,
}
.into(),
)
.await;
}
SessionEvent::DialError { address, error } => {
self.state.decrease();
self.dial_protocols.remove(&address);
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::DialerError {
address,
error: DialerErrorKind::TransportError(error),
}))
.send(
ServiceError::DialerError {
address,
error: DialerErrorKind::TransportError(error),
}
.into(),
)
.await;
}
#[cfg(not(target_family = "wasm"))]
SessionEvent::ListenError { address, error } => {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::ListenError {
address: address.clone(),
error: ListenErrorKind::TransportError(error),
}))
.send(
ServiceError::ListenError {
address: address.clone(),
error: ListenErrorKind::TransportError(error),
}
.into(),
)
.await;
if self.listens.remove(&address) {
#[cfg(feature = "upnp")]
Expand All @@ -986,9 +1020,7 @@ where

let _ignore = self
.handle_sender
.send(ServiceEventAndError::Event(ServiceEvent::ListenClose {
address,
}))
.send(ServiceEvent::ListenClose { address }.into())
.await;
} else {
// try start listen error
Expand All @@ -999,20 +1031,26 @@ where
if let Some(session_control) = self.sessions.get(&id) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::SessionTimeout {
session_context: Arc::clone(&session_control.inner),
}))
.send(
ServiceError::SessionTimeout {
session_context: Arc::clone(&session_control.inner),
}
.into(),
)
.await;
}
}
SessionEvent::MuxerError { id, error } => {
if let Some(session_control) = self.sessions.get(&id) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::MuxerError {
session_context: Arc::clone(&session_control.inner),
error,
}))
.send(
ServiceError::MuxerError {
session_context: Arc::clone(&session_control.inner),
error,
}
.into(),
)
.await;
}
}
Expand All @@ -1023,9 +1061,12 @@ where
} => {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Event(ServiceEvent::ListenStarted {
address: listen_address.clone(),
}))
.send(
ServiceEvent::ListenStarted {
address: listen_address.clone(),
}
.into(),
)
.await;
self.listens.insert(listen_address.clone());
self.state.decrease();
Expand All @@ -1039,9 +1080,7 @@ where
SessionEvent::ProtocolHandleError { error, proto_id } => {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(
ServiceError::ProtocolHandleError { error, proto_id },
))
.send(ServiceError::ProtocolHandleError { error, proto_id }.into())
.await;
// if handle panic, close service
self.handle_service_task(ServiceTask::Shutdown(false), Priority::High)
Expand All @@ -1051,9 +1090,12 @@ where
if let Some(session) = self.sessions.get(&id) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::SessionBlocked {
session_context: session.inner.clone(),
}))
.send(
ServiceError::SessionBlocked {
session_context: session.inner.clone(),
}
.into(),
)
.await;
}
}
Expand All @@ -1077,10 +1119,13 @@ where
if let Err(e) = self.dial_inner(address.clone(), target) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::DialerError {
address,
error: DialerErrorKind::TransportError(e),
}))
.send(
ServiceError::DialerError {
address,
error: DialerErrorKind::TransportError(e),
}
.into(),
)
.await;
}
}
Expand All @@ -1090,10 +1135,13 @@ where
if let Err(e) = self.listen_inner(address.clone()) {
let _ignore = self
.handle_sender
.send(ServiceEventAndError::Error(ServiceError::ListenError {
address,
error: ListenErrorKind::TransportError(e),
}))
.send(
ServiceError::ListenError {
address,
error: ListenErrorKind::TransportError(e),
}
.into(),
)
.await;
}
}
Expand Down Expand Up @@ -1176,9 +1224,7 @@ where
let mut events = futures::stream::iter(
self.listens
.drain()
.map(|address| {
ServiceEventAndError::Event(ServiceEvent::ListenClose { address })
})
.map(|address| ServiceEvent::ListenClose { address }.into())
.collect::<Vec<_>>(),
)
.map(Ok);
Expand Down
33 changes: 32 additions & 1 deletion tentacle/src/service/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,44 @@ use bytes::Bytes;

#[derive(Debug)]
pub(crate) enum ServiceEventAndError {
Event(ServiceEvent),
Event {
event: ServiceEvent,
wait_response: Option<futures::channel::oneshot::Sender<()>>,
},
Error(ServiceError),
Update {
listen_addrs: Vec<multiaddr::MultiAddr>,
},
}

impl ServiceEventAndError {
pub fn wait_response(self, tx: futures::channel::oneshot::Sender<()>) -> Self {
if let ServiceEventAndError::Event { event, .. } = self {
ServiceEventAndError::Event {
event,
wait_response: Some(tx),
}
} else {
self
}
}
}

impl From<ServiceEvent> for ServiceEventAndError {
fn from(event: ServiceEvent) -> Self {
ServiceEventAndError::Event {
event,
wait_response: None,
}
}
}

impl From<ServiceError> for ServiceEventAndError {
fn from(event: ServiceError) -> Self {
ServiceEventAndError::Error(event)
}
}

/// Error generated by the Service
#[derive(Debug)]
pub enum ServiceError {
Expand Down
Loading

0 comments on commit 8cf4f13

Please sign in to comment.