Skip to content

Commit

Permalink
feat: enable tcp base protocol listen on same port
Browse files Browse the repository at this point in the history
  • Loading branch information
driftluo committed Nov 27, 2024
1 parent 2f4a0a1 commit 25607f8
Show file tree
Hide file tree
Showing 10 changed files with 613 additions and 412 deletions.
3 changes: 2 additions & 1 deletion tentacle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ nohash-hasher = "0.2"

parking_lot = { version = "0.12", optional = true }
tokio-tungstenite = { version = "0.24", optional = true }
httparse = { version = "1.9", optional = true }
futures-timer = { version = "3.0.2", optional = true }
async-std = { version = "1", features = ["unstable"], optional = true }
async-io = { version = "1", optional = true }
Expand Down Expand Up @@ -76,7 +77,7 @@ nix = { version = "0.29", default-features = false, features = ["signal"] }

[features]
default = ["tokio-runtime", "tokio-timer"]
ws = ["tokio-tungstenite"]
ws = ["tokio-tungstenite", "httparse"]
tls = ["tokio-rustls"]
upnp = ["igd"]
secio-async-trait = ["secio/async-trait"]
Expand Down
2 changes: 1 addition & 1 deletion tentacle/examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ fn server() {
.unwrap();
#[cfg(feature = "ws")]
service
.listen("/ip4/127.0.0.1/tcp/1338/ws".parse().unwrap())
.listen("/ip4/127.0.0.1/tcp/1337/ws".parse().unwrap())
.await
.unwrap();
service.run().await
Expand Down
3 changes: 1 addition & 2 deletions tentacle/src/lock/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#![allow(dead_code)]
#![allow(dead_code, unused_imports)]

#[cfg(feature = "parking_lot")]
pub use parking_lot::{const_fair_mutex, const_mutex, const_rwlock, FairMutex, Mutex, RwLock};
#[cfg(not(feature = "parking_lot"))]
pub mod native;

#[allow(unused_imports)]
#[cfg(not(feature = "parking_lot"))]
pub use native::{Mutex, RwLock};
10 changes: 7 additions & 3 deletions tentacle/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ where
}
inner.listens.insert(listen_address.clone());

inner.spawn_listener(incoming, listen_address);
if !matches!(incoming, MultiIncoming::TcpUpgrade) {
inner.spawn_listener(incoming, listen_address);
}

Ok(addr)
}
Expand Down Expand Up @@ -1017,7 +1019,7 @@ where
if let Some(ref mut client) = self.igd_client {
client.remove(&address);
}

self.try_update_listens().await;
let _ignore = self
.handle_sender
.send(ServiceEvent::ListenClose { address }.into())
Expand Down Expand Up @@ -1075,7 +1077,9 @@ where
if let Some(client) = self.igd_client.as_mut() {
client.register(&listen_address)
}
self.spawn_listener(incoming, listen_address);
if !matches!(incoming, MultiIncoming::TcpUpgrade) {
self.spawn_listener(incoming, listen_address);
}
}
SessionEvent::ProtocolHandleError { error, proto_id } => {
let _ignore = self
Expand Down
128 changes: 61 additions & 67 deletions tentacle/src/transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod browser;
mod memory;
#[cfg(not(target_family = "wasm"))]
mod tcp;
#[cfg(not(target_family = "wasm"))]
mod tcp_base_listen;
#[cfg(all(feature = "tls", not(target_family = "wasm")))]
mod tls;
#[cfg(all(feature = "ws", not(target_family = "wasm")))]
Expand Down Expand Up @@ -93,24 +95,36 @@ pub fn find_type(addr: &Multiaddr) -> TransportType {
.unwrap_or(TransportType::Tcp)
}

pub(crate) fn parse_tls_domain_name(addr: &Multiaddr) -> Option<String> {
let mut iter = addr.iter();

iter.find_map(|proto| {
if let Protocol::Tls(s) = proto {
Some(s.to_string())
} else {
None
}
})
}

#[cfg(not(target_family = "wasm"))]
mod os {
use super::*;

use crate::{
runtime::{TcpListener, TcpStream},
service::config::TcpConfig,
utils::socketaddr_to_multiaddr,
};

use futures::{prelude::Stream, FutureExt, StreamExt};
use log::debug;
use std::{
collections::HashMap,
fmt,
future::Future,
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
Expand All @@ -120,17 +134,21 @@ mod os {
MemoryDialFuture, MemoryListenFuture, MemoryListener, MemorySocket, MemoryTransport,
};
use self::tcp::{TcpDialFuture, TcpListenFuture, TcpTransport};
use self::tcp_base_listen::{
TcpBaseListener, TcpBaseListenerEnum, UpgradeMode, UpgradeModeEnum,
};
#[cfg(feature = "tls")]
use self::tls::{TlsDialFuture, TlsListenFuture, TlsListener, TlsStream, TlsTransport};
use self::tls::{TlsDialFuture, TlsStream, TlsTransport};
#[cfg(feature = "ws")]
use self::ws::{WebsocketListener, WsDialFuture, WsListenFuture, WsStream, WsTransport};
use self::ws::{WsDialFuture, WsStream, WsTransport};
#[cfg(feature = "tls")]
use crate::service::config::TlsConfig;

#[derive(Clone)]
pub(crate) struct MultiTransport {
timeout: Duration,
tcp_config: TcpConfig,
listens_upgrade_modes: Arc<crate::lock::Mutex<HashMap<SocketAddr, UpgradeMode>>>,
#[cfg(feature = "tls")]
tls_config: Option<TlsConfig>,
}
Expand All @@ -140,6 +158,7 @@ mod os {
MultiTransport {
timeout,
tcp_config,
listens_upgrade_modes: Arc::new(crate::lock::Mutex::new(Default::default())),
#[cfg(feature = "tls")]
tls_config: None,
}
Expand All @@ -159,15 +178,27 @@ mod os {
fn listen(self, address: Multiaddr) -> Result<Self::ListenFuture> {
match find_type(&address) {
TransportType::Tcp => {
match TcpTransport::new(self.timeout, self.tcp_config.tcp).listen(address) {
match TcpTransport::new(self.timeout, self.tcp_config.tcp)
.listen_upgrade_modes(
UpgradeModeEnum::OnlyTcp.into(),
self.listens_upgrade_modes,
)
.listen(address)
{
Ok(future) => Ok(MultiListenFuture::Tcp(future)),
Err(e) => Err(e),
}
}
#[cfg(feature = "ws")]
TransportType::Ws => {
match WsTransport::new(self.timeout, self.tcp_config.ws).listen(address) {
Ok(future) => Ok(MultiListenFuture::Ws(future)),
match TcpTransport::new(self.timeout, self.tcp_config.ws)
.listen_upgrade_modes(
UpgradeModeEnum::OnlyWs.into(),
self.listens_upgrade_modes,
)
.listen(address)
{
Ok(future) => Ok(MultiListenFuture::Tcp(future)),
Err(e) => Err(e),
}
}
Expand All @@ -183,9 +214,17 @@ mod os {
let tls_config = self.tls_config.ok_or_else(|| {
TransportErrorKind::TlsError("tls config is not set".to_string())
})?;
TlsTransport::new(self.timeout, tls_config, self.tcp_config.tls)
match TcpTransport::new(self.timeout, self.tcp_config.tls)
.listen_upgrade_modes(
UpgradeModeEnum::OnlyTLS.into(),
self.listens_upgrade_modes,
)
.tls_config(tls_config)
.listen(address)
.map(MultiListenFuture::Tls)
{
Ok(future) => Ok(MultiListenFuture::Tcp(future)),
Err(e) => Err(e),
}
}
#[cfg(not(feature = "tls"))]
TransportType::Tls => Err(TransportErrorKind::NotSupported(address)),
Expand Down Expand Up @@ -232,35 +271,24 @@ mod os {
pub enum MultiListenFuture {
Tcp(TcpListenFuture),
Memory(MemoryListenFuture),
#[cfg(feature = "ws")]
Ws(WsListenFuture),
#[cfg(feature = "tls")]
Tls(TlsListenFuture),
}

impl Future for MultiListenFuture {
type Output = Result<(Multiaddr, MultiIncoming)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
MultiListenFuture::Tcp(inner) => Pin::new(
&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Tcp(res.1)))),
)
MultiListenFuture::Tcp(inner) => Pin::new(&mut inner.map(|res| {
res.map(|res| match res.1 {
TcpBaseListenerEnum::New(i) => (res.0, MultiIncoming::Tcp(i)),
TcpBaseListenerEnum::Upgrade => (res.0, MultiIncoming::TcpUpgrade),
})
}))
.poll(cx),
MultiListenFuture::Memory(inner) => Pin::new(
&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Memory(res.1)))),
)
.poll(cx),
#[cfg(feature = "ws")]
MultiListenFuture::Ws(inner) => {
Pin::new(&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Ws(res.1)))))
.poll(cx)
}
#[cfg(feature = "tls")]
MultiListenFuture::Tls(inner) => Pin::new(
&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Tls(res.1)))),
)
.poll(cx),
}
}
}
Expand Down Expand Up @@ -381,59 +409,25 @@ mod os {
}

pub enum MultiIncoming {
Tcp(TcpListener),
TcpUpgrade,
Tcp(TcpBaseListener),
Memory(MemoryListener),
#[cfg(feature = "ws")]
Ws(WebsocketListener),
#[cfg(feature = "tls")]
Tls(TlsListener),
}

impl Stream for MultiIncoming {
type Item = std::result::Result<(Multiaddr, MultiStream), io::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self.get_mut() {
MultiIncoming::Tcp(inner) => {
loop {
match inner.poll_accept(cx)? {
// Why can't get the peer address of the connected stream ?
// Error will be "Transport endpoint is not connected",
// so why incoming will appear unconnected stream ?
Poll::Ready((stream, _)) => match stream.peer_addr() {
Ok(remote_address) => {
break Poll::Ready(Some(Ok((
socketaddr_to_multiaddr(remote_address),
MultiStream::Tcp(stream),
))))
}
Err(err) => {
debug!("stream get peer address error: {:?}", err);
}
},
Poll::Pending => break Poll::Pending,
}
}
}
MultiIncoming::Memory(inner) => match inner.poll_next_unpin(cx)? {
Poll::Ready(Some((addr, stream))) => {
Poll::Ready(Some(Ok((addr, MultiStream::Memory(stream)))))
}
MultiIncoming::Tcp(inner) => match inner.poll_next_unpin(cx)? {
Poll::Ready(Some((addr, stream))) => Poll::Ready(Some(Ok((addr, stream)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
},
#[cfg(feature = "ws")]
MultiIncoming::Ws(inner) => match inner.poll_next_unpin(cx)? {
Poll::Ready(Some((addr, stream))) => {
Poll::Ready(Some(Ok((addr, MultiStream::Ws(Box::new(stream))))))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
},
#[cfg(feature = "tls")]
MultiIncoming::Tls(inner) => match inner.poll_next_unpin(cx)? {
MultiIncoming::TcpUpgrade => unreachable!(),
MultiIncoming::Memory(inner) => match inner.poll_next_unpin(cx)? {
Poll::Ready(Some((addr, stream))) => {
Poll::Ready(Some(Ok((addr, MultiStream::Tls(stream)))))
Poll::Ready(Some(Ok((addr, MultiStream::Memory(stream)))))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
Expand Down
Loading

0 comments on commit 25607f8

Please sign in to comment.