diff --git a/Cargo.toml b/Cargo.toml index 05b5ccc8860..eac6e34607a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,15 +14,16 @@ default = [ "deflate", "dns-async-std", "floodsub", + "gossipsub", "identify", "kad", - "gossipsub", "mdns", "mplex", "noise", "ping", "plaintext", "pnet", + "quic", "relay", "request-response", "rendezvous", @@ -37,9 +38,9 @@ deflate = ["libp2p-deflate"] dns-async-std = ["libp2p-dns", "libp2p-dns/async-std"] dns-tokio = ["libp2p-dns", "libp2p-dns/tokio"] floodsub = ["libp2p-floodsub"] +gossipsub = ["libp2p-gossipsub"] identify = ["libp2p-identify", "libp2p-metrics/identify"] kad = ["libp2p-kad", "libp2p-metrics/kad"] -gossipsub = ["libp2p-gossipsub"] metrics = ["libp2p-metrics"] mdns = ["libp2p-mdns"] mplex = ["libp2p-mplex"] @@ -47,6 +48,7 @@ noise = ["libp2p-noise"] ping = ["libp2p-ping", "libp2p-metrics/ping"] plaintext = ["libp2p-plaintext"] pnet = ["libp2p-pnet"] +quic = ["libp2p-quic"] relay = ["libp2p-relay"] request-response = ["libp2p-request-response"] rendezvous = ["libp2p-rendezvous"] @@ -97,6 +99,7 @@ wasm-timer = "0.2.4" libp2p-deflate = { version = "0.30.0", path = "transports/deflate", optional = true } libp2p-dns = { version = "0.30.0", path = "transports/dns", optional = true, default-features = false } libp2p-mdns = { version = "0.32.0", path = "protocols/mdns", optional = true } +libp2p-quic = { version = "0.6.0", path = "transports/quic", optional = true } libp2p-tcp = { version = "0.30.0", path = "transports/tcp", default-features = false, optional = true } libp2p-websocket = { version = "0.31.0", path = "transports/websocket", optional = true } @@ -132,6 +135,7 @@ members = [ "transports/noise", "transports/plaintext", "transports/pnet", + "transports/quic", "transports/tcp", "transports/uds", "transports/websocket", diff --git a/src/lib.rs b/src/lib.rs index b8728005ea3..e75c306e245 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,11 @@ pub use libp2p_plaintext as plaintext; #[cfg_attr(docsrs, doc(cfg(feature = "pnet")))] #[doc(inline)] pub use libp2p_pnet as pnet; +#[cfg(feature = "quic")] +#[cfg_attr(docsrs, doc(cfg(feature = "quic")))] +#[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] +#[doc(inline)] +pub use libp2p_quic as quic; #[cfg(feature = "relay")] #[cfg_attr(docsrs, doc(cfg(feature = "relay")))] #[doc(inline)] diff --git a/transports/quic/Cargo.toml b/transports/quic/Cargo.toml new file mode 100644 index 00000000000..f02545e13dd --- /dev/null +++ b/transports/quic/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "libp2p-quic" +version = "0.6.0" +authors = ["David Craven ", "Parity Technologies "] +edition = "2018" +description = "TLS and Noise based QUIC transport implementation for libp2p" +repository = "https://github.com/libp2p/rust-libp2p" +license = "MIT" + +[features] +noise = ["quinn-noise", "ed25519-dalek"] +tls = ["barebones-x509", "quinn-proto/tls-rustls", "rcgen", "ring", "rustls", "untrusted", "webpki", "yasna"] + +[dependencies] +async-global-executor = "2.0.2" +async-io = "1.6.0" +barebones-x509 = { version = "0.5.0", optional = true, features = ["webpki", "rustls", "std"] } +bytes = "1.0.1" +ed25519-dalek = { version = "1.0.1", optional = true } +futures = "0.3.15" +if-watch = "0.2.2" +libp2p-core = { version = "0.30.0", path = "../../core" } +multihash = { version = "0.14.0", default-features = false } +parking_lot = "0.11.1" +quinn-noise = { version = "0.3.0", optional = true } +quinn-proto = { version = "0.7.3", default-features = false } +rcgen = { version = "0.8.11", optional = true } +ring = { version = "0.16.20", optional = true } +rustls = { version = "0.19.1", optional = true, features = ["dangerous_configuration"] } +thiserror = "1.0.26" +tracing = "0.1.26" +udp-socket = "0.1.5" +untrusted = { version = "0.7.1", optional = true } +webpki = { version = "0.21.4", optional = true, features = ["std"] } +yasna = { version = "0.4.0", optional = true } + +[dev-dependencies] +anyhow = "1.0.41" +async-std = { version = "1.9.0", features = ["attributes"] } +async-trait = "0.1.50" +libp2p = { version = "0.40.0", default-features = false, features = ["request-response"], path = "../.." } +log-panics = "2.0.0" +rand = "0.8.4" +rand_core = "0.5.1" +tracing-subscriber = "0.2.19" diff --git a/transports/quic/src/crypto.rs b/transports/quic/src/crypto.rs new file mode 100644 index 00000000000..81d1c8193b3 --- /dev/null +++ b/transports/quic/src/crypto.rs @@ -0,0 +1,199 @@ +use libp2p_core::PeerId; +use quinn_proto::crypto::Session; +use quinn_proto::TransportConfig; +use std::sync::Arc; + +pub struct CryptoConfig { + pub keypair: C::Keypair, + pub psk: Option<[u8; 32]>, + pub keylogger: Option, + pub transport: Arc, +} + +#[cfg(feature = "noise")] +trait CloneKeypair { + fn clone_keypair(&self) -> Self; +} + +#[cfg(feature = "noise")] +impl CloneKeypair for ed25519_dalek::Keypair { + fn clone_keypair(&self) -> Self { + ed25519_dalek::Keypair::from_bytes(&self.to_bytes()).expect("serde works") + } +} + +pub trait ToLibp2p { + fn to_public(&self) -> libp2p_core::identity::PublicKey; + fn to_peer_id(&self) -> PeerId { + self.to_public().to_peer_id() + } +} + +#[cfg(feature = "noise")] +impl ToLibp2p for ed25519_dalek::Keypair { + fn to_public(&self) -> libp2p_core::identity::PublicKey { + self.public.to_public() + } +} + +#[cfg(feature = "noise")] +impl ToLibp2p for ed25519_dalek::PublicKey { + fn to_public(&self) -> libp2p_core::identity::PublicKey { + let public_key = self.to_bytes(); + let public_key = + libp2p_core::identity::ed25519::PublicKey::decode(&public_key[..]).unwrap(); + libp2p_core::identity::PublicKey::Ed25519(public_key) + } +} + +#[cfg(feature = "tls")] +impl ToLibp2p for libp2p_core::identity::Keypair { + fn to_public(&self) -> libp2p_core::identity::PublicKey { + self.public() + } +} + +pub trait Crypto: std::fmt::Debug + Clone + 'static { + type Session: Session + Unpin; + type Keylogger: Send + Sync; + type Keypair: Send + Sync + ToLibp2p; + type PublicKey: Send + std::fmt::Debug + PartialEq; + + fn new_server_config( + config: &Arc>, + ) -> ::ServerConfig; + fn new_client_config( + config: &Arc>, + remote_public: Self::PublicKey, + ) -> ::ClientConfig; + fn supported_quic_versions() -> Vec; + fn default_quic_version() -> u32; + fn peer_id(session: &Self::Session) -> Option; + fn extract_public_key(generic_key: libp2p_core::PublicKey) -> Option; + fn keylogger() -> Self::Keylogger; +} + +#[cfg(feature = "noise")] +#[derive(Clone, Copy, Debug)] +pub struct NoiseCrypto; + +#[cfg(feature = "noise")] +impl Crypto for NoiseCrypto { + type Session = quinn_noise::NoiseSession; + type Keylogger = Arc; + type Keypair = ed25519_dalek::Keypair; + type PublicKey = ed25519_dalek::PublicKey; + + fn new_server_config( + config: &Arc>, + ) -> ::ServerConfig { + Arc::new( + quinn_noise::NoiseServerConfig { + keypair: config.keypair.clone_keypair(), + psk: config.psk, + keylogger: config.keylogger.clone(), + supported_protocols: vec![b"libp2p".to_vec()], + } + .into(), + ) + } + + fn new_client_config( + config: &Arc>, + remote_public_key: Self::PublicKey, + ) -> ::ClientConfig { + quinn_noise::NoiseClientConfig { + keypair: config.keypair.clone_keypair(), + psk: config.psk, + alpn: b"libp2p".to_vec(), + remote_public_key, + keylogger: config.keylogger.clone(), + } + .into() + } + + fn supported_quic_versions() -> Vec { + quinn_noise::SUPPORTED_QUIC_VERSIONS.to_vec() + } + + fn default_quic_version() -> u32 { + quinn_noise::DEFAULT_QUIC_VERSION + } + + fn peer_id(session: &Self::Session) -> Option { + Some(session.peer_identity()?.to_peer_id()) + } + + fn extract_public_key(generic_key: libp2p_core::PublicKey) -> Option { + let public_key = if let libp2p_core::PublicKey::Ed25519(public_key) = generic_key { + public_key.encode() + } else { + return None; + }; + Self::PublicKey::from_bytes(&public_key).ok() + } + + fn keylogger() -> Self::Keylogger { + Arc::new(quinn_noise::KeyLogFile::new()) + } +} + +#[cfg(feature = "tls")] +#[derive(Clone, Copy, Debug)] +pub struct TlsCrypto; + +#[cfg(feature = "tls")] +impl Crypto for TlsCrypto { + type Session = quinn_proto::crypto::rustls::TlsSession; + type Keylogger = Arc; + type Keypair = libp2p_core::identity::Keypair; + type PublicKey = libp2p_core::identity::PublicKey; + + fn new_server_config( + config: &Arc>, + ) -> ::ServerConfig { + assert!(config.psk.is_none(), "invalid config"); + let mut server = crate::tls::make_server_config(&config.keypair).expect("invalid config"); + if let Some(key_log) = config.keylogger.clone() { + server.key_log = key_log; + } + Arc::new(server) + } + + fn new_client_config( + config: &Arc>, + remote_public: Self::PublicKey, + ) -> ::ClientConfig { + assert!(config.psk.is_none(), "invalid config"); + let mut client = + crate::tls::make_client_config(&config.keypair, remote_public.to_peer_id()) + .expect("invalid config"); + if let Some(key_log) = config.keylogger.clone() { + client.key_log = key_log; + } + Arc::new(client) + } + + fn supported_quic_versions() -> Vec { + quinn_proto::DEFAULT_SUPPORTED_VERSIONS.to_vec() + } + + fn default_quic_version() -> u32 { + quinn_proto::DEFAULT_SUPPORTED_VERSIONS[0] + } + + fn peer_id(session: &Self::Session) -> Option { + let certificate = session.get_peer_certificates()?.into_iter().next()?; + Some(crate::tls::extract_peerid_or_panic( + quinn_proto::Certificate::from(certificate).as_der(), + )) + } + + fn extract_public_key(generic_key: libp2p_core::PublicKey) -> Option { + Some(generic_key) + } + + fn keylogger() -> Self::Keylogger { + Arc::new(rustls::KeyLogFile::new()) + } +} diff --git a/transports/quic/src/endpoint.rs b/transports/quic/src/endpoint.rs new file mode 100644 index 00000000000..3295b5bbf58 --- /dev/null +++ b/transports/quic/src/endpoint.rs @@ -0,0 +1,471 @@ +use crate::crypto::{Crypto, CryptoConfig}; +use crate::muxer::QuicMuxer; +use crate::{QuicConfig, QuicError}; +use futures::channel::{mpsc, oneshot}; +use futures::prelude::*; +use quinn_proto::crypto::Session; +use quinn_proto::generic::{ClientConfig, ServerConfig}; +use quinn_proto::{ + ConnectionEvent, ConnectionHandle, DatagramEvent, EcnCodepoint, EndpointEvent, Transmit, +}; +use std::collections::{HashMap, VecDeque}; +use std::io::IoSliceMut; +use std::mem::MaybeUninit; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Instant; +use udp_socket::{RecvMeta, SocketType, UdpCapabilities, UdpSocket, BATCH_SIZE}; + +/// Message sent to the endpoint background task. +#[derive(Debug)] +enum ToEndpoint { + /// Instructs the endpoint to start connecting to the given address. + Dial { + /// UDP address to connect to. + addr: SocketAddr, + /// The remotes public key. + public_key: C::PublicKey, + /// Channel to return the result of the dialing to. + tx: oneshot::Sender, QuicError>>, + }, + /// Sent by a `quinn_proto` connection when the endpoint needs to process an event generated + /// by a connection. The event itself is opaque to us. + ConnectionEvent { + connection_id: ConnectionHandle, + event: EndpointEvent, + }, + /// Instruct the endpoint to transmit a packet on its UDP socket. + Transmit(Transmit), +} + +#[derive(Debug)] +pub struct TransportChannel { + tx: mpsc::UnboundedSender>, + rx: mpsc::Receiver, QuicError>>, + port: u16, + ty: SocketType, +} + +impl TransportChannel { + pub fn dial( + &mut self, + addr: SocketAddr, + public_key: C::PublicKey, + ) -> oneshot::Receiver, QuicError>> { + let (tx, rx) = oneshot::channel(); + let msg = ToEndpoint::Dial { + addr, + public_key, + tx, + }; + self.tx.unbounded_send(msg).expect("endpoint has crashed"); + rx + } + + pub fn poll_incoming( + &mut self, + cx: &mut Context, + ) -> Poll, QuicError>>> { + Pin::new(&mut self.rx).poll_next(cx) + } + + pub fn port(&self) -> u16 { + self.port + } + + pub fn ty(&self) -> SocketType { + self.ty + } +} + +#[derive(Debug)] +pub struct ConnectionChannel { + id: ConnectionHandle, + tx: mpsc::UnboundedSender>, + rx: mpsc::Receiver, + port: u16, + max_datagrams: usize, +} + +impl ConnectionChannel { + pub fn poll_channel_events(&mut self, cx: &mut Context) -> Poll { + match Pin::new(&mut self.rx).poll_next(cx) { + Poll::Ready(Some(event)) => Poll::Ready(event), + Poll::Ready(None) => panic!("endpoint has crashed"), + Poll::Pending => Poll::Pending, + } + } + + pub fn send_endpoint_event(&mut self, event: EndpointEvent) { + let msg = ToEndpoint::ConnectionEvent { + connection_id: self.id, + event, + }; + self.tx.unbounded_send(msg).expect("endpoint has crashed") + } + + pub fn send_transmit(&mut self, transmit: Transmit) { + let msg = ToEndpoint::Transmit(transmit); + self.tx.unbounded_send(msg).expect("endpoint has crashed") + } + + pub fn port(&self) -> u16 { + self.port + } + + pub fn max_datagrams(&self) -> usize { + self.max_datagrams + } +} + +#[derive(Debug)] +struct EndpointChannel { + rx: mpsc::UnboundedReceiver>, + tx: mpsc::Sender, QuicError>>, + port: u16, + max_datagrams: usize, + connection_tx: mpsc::UnboundedSender>, +} + +impl EndpointChannel { + pub fn poll_next_event(&mut self, cx: &mut Context) -> Poll>> { + Pin::new(&mut self.rx).poll_next(cx) + } + + pub fn create_connection( + &self, + id: ConnectionHandle, + ) -> (ConnectionChannel, mpsc::Sender) { + let (tx, rx) = mpsc::channel(12); + let channel = ConnectionChannel { + id, + tx: self.connection_tx.clone(), + rx, + port: self.port, + max_datagrams: self.max_datagrams, + }; + (channel, tx) + } +} + +type QuinnEndpointConfig = quinn_proto::generic::EndpointConfig; +type QuinnEndpoint = quinn_proto::generic::Endpoint; + +pub struct EndpointConfig { + socket: UdpSocket, + endpoint: QuinnEndpoint, + port: u16, + crypto_config: Arc>, + capabilities: UdpCapabilities, +} + +impl EndpointConfig { + pub fn new(mut config: QuicConfig, addr: SocketAddr) -> Result { + config.transport.max_concurrent_uni_streams(0)?; + config.transport.datagram_receive_buffer_size(None); + let transport = Arc::new(config.transport); + + let crypto_config = Arc::new(CryptoConfig { + keypair: config.keypair, + psk: config.psk, + keylogger: config.keylogger, + transport: transport.clone(), + }); + + let mut server_config = ServerConfig::::default(); + server_config.transport = transport; + server_config.crypto = C::new_server_config(&crypto_config); + + let mut endpoint_config = QuinnEndpointConfig::default(); + endpoint_config + .supported_versions(C::supported_quic_versions(), C::default_quic_version())?; + + let socket = UdpSocket::bind(addr)?; + let port = socket.local_addr()?.port(); + let endpoint = quinn_proto::generic::Endpoint::new( + Arc::new(endpoint_config), + Some(Arc::new(server_config)), + ); + let capabilities = UdpSocket::capabilities()?; + Ok(Self { + socket, + endpoint, + port, + crypto_config, + capabilities, + }) + } + + pub fn spawn(self) -> TransportChannel + where + ::ClientConfig: Send + Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, + { + let (tx1, rx1) = mpsc::unbounded(); + let (tx2, rx2) = mpsc::channel(1); + let transport = TransportChannel { + tx: tx1, + rx: rx2, + port: self.port, + ty: self.socket.socket_type(), + }; + let endpoint = EndpointChannel { + tx: tx2, + rx: rx1, + port: self.port, + max_datagrams: self.capabilities.max_gso_segments, + connection_tx: transport.tx.clone(), + }; + async_global_executor::spawn(Endpoint::new(endpoint, self)).detach(); + transport + } +} + +struct Endpoint { + channel: EndpointChannel, + endpoint: QuinnEndpoint, + socket: UdpSocket, + crypto_config: Arc>, + connections: HashMap>, + outgoing: VecDeque, + recv_buf: Box<[u8]>, + incoming_slot: Option>, + event_slot: Option<(ConnectionHandle, ConnectionEvent)>, +} + +impl Endpoint { + pub fn new(channel: EndpointChannel, config: EndpointConfig) -> Self { + let max_udp_payload_size = config + .endpoint + .config() + .get_max_udp_payload_size() + .min(u16::MAX as _) as usize; + let recv_buf = vec![0; max_udp_payload_size * BATCH_SIZE].into_boxed_slice(); + Self { + channel, + endpoint: config.endpoint, + socket: config.socket, + crypto_config: config.crypto_config, + connections: Default::default(), + outgoing: Default::default(), + recv_buf, + incoming_slot: None, + event_slot: None, + } + } + + pub fn transmit(&mut self, transmit: Transmit) { + let ecn = transmit + .ecn + .map(|ecn| udp_socket::EcnCodepoint::from_bits(ecn as u8)) + .unwrap_or_default(); + let transmit = udp_socket::Transmit { + destination: transmit.destination, + contents: transmit.contents, + ecn, + segment_size: transmit.segment_size, + src_ip: transmit.src_ip, + }; + self.outgoing.push_back(transmit); + } + + fn send_incoming(&mut self, muxer: QuicMuxer, cx: &mut Context) -> bool { + assert!(self.incoming_slot.is_none()); + match self.channel.tx.poll_ready(cx) { + Poll::Pending => { + self.incoming_slot = Some(muxer); + true + } + Poll::Ready(Ok(())) => { + self.channel.tx.try_send(Ok(muxer)).ok(); + false + } + Poll::Ready(_err) => false, + } + } + + fn send_event( + &mut self, + id: ConnectionHandle, + event: ConnectionEvent, + cx: &mut Context, + ) -> bool { + assert!(self.event_slot.is_none()); + let conn = self.connections.get_mut(&id).unwrap(); + match conn.poll_ready(cx) { + Poll::Pending => { + self.event_slot = Some((id, event)); + true + } + Poll::Ready(Ok(())) => { + conn.try_send(event).ok(); + false + } + Poll::Ready(_err) => false, + } + } +} + +impl Future for Endpoint +where + ::ClientConfig: Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let me = Pin::into_inner(self); + + if let Some(muxer) = me.incoming_slot.take() { + if !me.send_incoming(muxer, cx) { + tracing::info!("cleared incoming slot"); + } + } + + if let Some((id, event)) = me.event_slot.take() { + if !me.send_event(id, event, cx) { + tracing::info!("cleared event slot"); + } + } + + while let Some(transmit) = me.endpoint.poll_transmit() { + me.transmit(transmit); + } + + if me.event_slot.is_none() { + while let Poll::Ready(event) = me.channel.poll_next_event(cx) { + match event { + Some(ToEndpoint::Dial { + addr, + public_key, + tx, + }) => { + let crypto = C::new_client_config(&me.crypto_config, public_key); + let client_config = ClientConfig { + transport: me.crypto_config.transport.clone(), + crypto, + }; + let (id, connection) = + match me.endpoint.connect(client_config, addr, "server_name") { + Ok(c) => c, + Err(err) => { + tracing::error!("dial failure: {}", err); + let _ = tx.send(Err(err.into())); + continue; + } + }; + let (channel, conn) = me.channel.create_connection(id); + me.connections.insert(id, conn); + let muxer = QuicMuxer::new(channel, connection); + tx.send(Ok(muxer)).ok(); + } + Some(ToEndpoint::ConnectionEvent { + connection_id, + event, + }) => { + let is_drained_event = event.is_drained(); + if is_drained_event { + me.connections.remove(&connection_id); + } + if let Some(event) = me.endpoint.handle_event(connection_id, event) { + if me.send_event(connection_id, event, cx) { + tracing::info!("filled event slot"); + break; + } + } + } + Some(ToEndpoint::Transmit(transmit)) => { + me.transmit(transmit); + } + None => { + me.endpoint.reject_new_connections(); + return Poll::Ready(()); + } + } + } + } + + while !me.outgoing.is_empty() { + let transmits: &[_] = me.outgoing.make_contiguous(); + match me.socket.poll_send(cx, transmits) { + Poll::Ready(Ok(n)) => { + me.outgoing.drain(..n); + } + Poll::Ready(Err(err)) => tracing::error!("send_to: {}", err), + Poll::Pending => break, + } + } + + if me.event_slot.is_none() && me.incoming_slot.is_none() { + let mut metas = [RecvMeta::default(); BATCH_SIZE]; + let mut iovs = MaybeUninit::<[IoSliceMut; BATCH_SIZE]>::uninit(); + fn init_iovs<'a>( + iovs: &'a mut MaybeUninit<[IoSliceMut<'a>; BATCH_SIZE]>, + recv_buf: &'a mut [u8], + ) -> &'a mut [IoSliceMut<'a>] { + let chunk_size = recv_buf.len() / BATCH_SIZE; + let chunks = recv_buf.chunks_mut(chunk_size); + // every iovs elem must be initialized with an according elem from buf chunks + assert_eq!(chunks.len(), BATCH_SIZE); + chunks.enumerate().for_each(|(i, buf)| unsafe { + iovs.as_mut_ptr() + .cast::() + .add(i) + .write(IoSliceMut::new(buf)); + }); + + unsafe { + // SAFETY: all elements are initialized + iovs.assume_init_mut() + } + } + let mut recv_buf = core::mem::replace(&mut me.recv_buf, Vec::new().into_boxed_slice()); + let iovs = init_iovs(&mut iovs, &mut recv_buf); + while let Poll::Ready(result) = me.socket.poll_recv(cx, iovs, &mut metas) { + let n = match result { + Ok(n) => n, + Err(err) => { + tracing::error!("recv_from: {}", err); + continue; + } + }; + for i in 0..n { + let meta = &metas[i]; + let packet = From::from(&iovs[i][..meta.len]); + let ecn = meta + .ecn + .map(|ecn| EcnCodepoint::from_bits(ecn as u8)) + .unwrap_or_default(); + match me + .endpoint + .handle(Instant::now(), meta.source, meta.dst_ip, ecn, packet) + { + None => {} + Some((id, DatagramEvent::ConnectionEvent(event))) => { + if me.send_event(id, event, cx) { + tracing::info!("filled event slot"); + break; + } + } + Some((id, DatagramEvent::NewConnection(connection))) => { + let (channel, tx) = me.channel.create_connection(id); + me.connections.insert(id, tx); + let muxer = QuicMuxer::new(channel, connection); + if me.send_incoming(muxer, cx) { + tracing::info!("filled incoming slot"); + break; + } + } + } + } + } + me.recv_buf = recv_buf; + } + + Poll::Pending + } +} diff --git a/transports/quic/src/lib.rs b/transports/quic/src/lib.rs new file mode 100644 index 00000000000..9185922c6cb --- /dev/null +++ b/transports/quic/src/lib.rs @@ -0,0 +1,86 @@ +mod crypto; +mod endpoint; +mod muxer; +#[cfg(feature = "tls")] +mod tls; +mod transport; + +pub use crate::crypto::Crypto; +#[cfg(feature = "noise")] +pub use crate::crypto::NoiseCrypto; +#[cfg(feature = "tls")] +pub use crate::crypto::TlsCrypto; +pub use crate::crypto::ToLibp2p; +pub use crate::muxer::{QuicMuxer, QuicMuxerError}; +pub use crate::transport::{QuicDial, QuicTransport}; +#[cfg(feature = "noise")] +pub use quinn_noise::{KeyLog, KeyLogFile}; +pub use quinn_proto::{ConfigError, ConnectError, ConnectionError, TransportConfig}; + +use libp2p_core::transport::TransportError; +use libp2p_core::Multiaddr; +use quinn_proto::crypto::Session; +use thiserror::Error; + +/// Quic configuration. +pub struct QuicConfig { + pub keypair: C::Keypair, + pub psk: Option<[u8; 32]>, + pub transport: TransportConfig, + pub keylogger: Option, +} + +impl std::fmt::Debug for QuicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("QuicConfig") + .field("keypair", &self.keypair.to_public()) + .field("psk", &self.psk) + .field("transport", &self.transport) + .finish() + } +} + +impl QuicConfig +where + ::ClientConfig: Send + Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + /// Creates a new config from a keypair. + pub fn new(keypair: C::Keypair) -> Self { + Self { + keypair, + psk: None, + transport: TransportConfig::default(), + keylogger: None, + } + } + + /// Enable keylogging. + pub fn enable_keylogger(&mut self) -> &mut Self { + self.keylogger = Some(C::keylogger()); + self + } + + /// Spawns a new endpoint. + pub async fn listen_on( + self, + addr: Multiaddr, + ) -> Result, TransportError> { + QuicTransport::new(self, addr).await + } +} + +#[derive(Debug, Error)] +pub enum QuicError { + #[error("{0}")] + Config(#[from] ConfigError), + #[error("{0}")] + Connect(#[from] ConnectError), + #[error("{0}")] + Muxer(#[from] QuicMuxerError), + #[error("{0}")] + Io(#[from] std::io::Error), + #[error("a `StreamMuxerEvent` was generated before the handshake was complete.")] + UpgradeError, +} diff --git a/transports/quic/src/muxer.rs b/transports/quic/src/muxer.rs new file mode 100644 index 00000000000..29f6b2a3ced --- /dev/null +++ b/transports/quic/src/muxer.rs @@ -0,0 +1,418 @@ +use crate::crypto::Crypto; +use crate::endpoint::ConnectionChannel; +use async_io::Timer; +use futures::prelude::*; +use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; +use libp2p_core::{Multiaddr, PeerId}; +use parking_lot::Mutex; +use quinn_proto::generic::Connection; +use quinn_proto::{ + ConnectionError, Dir, Event, FinishError, ReadError, ReadableError, StreamEvent, StreamId, + VarInt, WriteError, +}; +use std::collections::{HashMap, VecDeque}; +use std::io::Write; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; +use std::time::Instant; +use thiserror::Error; + +/// State for a single opened QUIC connection. +pub struct QuicMuxer { + inner: Mutex>, +} + +impl std::fmt::Debug for QuicMuxer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "QuicMuxer") + } +} + +/// Mutex protected fields of [`QuicMuxer`]. +struct QuicMuxerInner { + /// Accept incoming streams. + accept_incoming: bool, + /// Endpoint channel. + endpoint: ConnectionChannel, + /// Inner connection object that yields events. + connection: Connection, + /// Connection waker. + waker: Option, + /// Connection timer. + timer: Option, + /// State of all open substreams. + substreams: HashMap, + /// Pending substreams. + pending_substreams: VecDeque, + /// Close waker. + close_waker: Option, +} + +/// State of a single substream. +#[derive(Debug, Default)] +struct SubstreamState { + /// Waker to wake if the substream becomes readable. + read_waker: Option, + /// Waker to wake if the substream becomes writable. + write_waker: Option, +} + +impl QuicMuxer { + pub fn new(endpoint: ConnectionChannel, connection: Connection) -> Self { + Self { + inner: Mutex::new(QuicMuxerInner { + accept_incoming: false, + endpoint, + connection, + waker: None, + timer: None, + substreams: Default::default(), + pending_substreams: Default::default(), + close_waker: None, + }), + } + } + + pub fn is_handshaking(&self) -> bool { + self.inner.lock().connection.is_handshaking() + } + + pub fn peer_id(&self) -> Option { + let inner = self.inner.lock(); + let session = inner.connection.crypto_session(); + C::peer_id(session) + } + + pub fn local_addr(&self) -> Multiaddr { + let inner = self.inner.lock(); + let ip = inner + .connection + .local_ip() + .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + let addr = SocketAddr::new(ip, inner.endpoint.port()); + crate::transport::socketaddr_to_multiaddr(&addr) + } + + pub fn remote_addr(&self) -> Multiaddr { + let inner = self.inner.lock(); + let addr = inner.connection.remote_address(); + crate::transport::socketaddr_to_multiaddr(&addr) + } + + pub(crate) fn set_accept_incoming(&self, accept: bool) { + let mut inner = self.inner.lock(); + inner.accept_incoming = accept; + } +} + +impl StreamMuxer for QuicMuxer { + type Substream = StreamId; + type OutboundSubstream = (); + type Error = QuicMuxerError; + + fn poll_event( + &self, + cx: &mut Context, + ) -> Poll, Self::Error>> { + let mut inner = self.inner.lock(); + let now = Instant::now(); + + while let Poll::Ready(event) = inner.endpoint.poll_channel_events(cx) { + inner.connection.handle_event(event); + } + + let _max_datagrams = inner.endpoint.max_datagrams(); + while let Some(transmit) = inner.connection.poll_transmit(now) { + inner.endpoint.send_transmit(transmit); + } + + loop { + if let Some(timer) = inner.timer.as_mut() { + match Pin::new(timer).poll(cx) { + Poll::Ready(expired) => { + inner.connection.handle_timeout(expired); + inner.timer = None; + } + Poll::Pending => break, + } + } else if let Some(when) = inner.connection.poll_timeout() { + inner.timer = Some(Timer::at(when)); + } else { + break; + } + } + + while let Some(event) = inner.connection.poll_endpoint_events() { + inner.endpoint.send_endpoint_event(event); + } + + while let Some(event) = inner.connection.poll() { + match event { + Event::HandshakeDataReady => {} + Event::Connected => { + // Break here so that the noise upgrade can finish. + return Poll::Pending; + } + Event::ConnectionLost { reason } => { + tracing::debug!("connection lost because of {}", reason); + inner.substreams.clear(); + if let Some(waker) = inner.close_waker.take() { + waker.wake(); + } + return Poll::Ready(Err(QuicMuxerError::ConnectionLost { reason })); + } + Event::Stream(StreamEvent::Opened { dir: Dir::Bi }) => { + // handled at end. + } + Event::Stream(StreamEvent::Readable { id }) => { + tracing::trace!("stream readable {}", id); + if let Some(substream) = inner.substreams.get_mut(&id) { + if let Some(waker) = substream.read_waker.take() { + waker.wake(); + } + } + } + Event::Stream(StreamEvent::Writable { id }) => { + tracing::trace!("stream writable {}", id); + if let Some(substream) = inner.substreams.get_mut(&id) { + if let Some(waker) = substream.write_waker.take() { + waker.wake(); + } + } + } + Event::Stream(StreamEvent::Finished { id }) => { + tracing::trace!("stream finished {}", id); + if let Some(substream) = inner.substreams.get_mut(&id) { + if let Some(waker) = substream.read_waker.take() { + waker.wake(); + } + if let Some(waker) = substream.write_waker.take() { + waker.wake(); + } + } + } + Event::Stream(StreamEvent::Stopped { id, error_code }) => { + tracing::debug!("substream {} stopped with error {}", id, error_code); + inner.substreams.remove(&id); + return Poll::Ready(Err(QuicMuxerError::StreamStopped { id, error_code })); + } + Event::Stream(StreamEvent::Available { dir: Dir::Bi }) => { + tracing::trace!("stream available"); + if let Some(waker) = inner.pending_substreams.pop_front() { + waker.wake(); + } + } + Event::Stream(StreamEvent::Opened { dir: Dir::Uni }) + | Event::Stream(StreamEvent::Available { dir: Dir::Uni }) + | Event::DatagramReceived => { + // We don't use datagrams or unidirectional streams. If these events + // happen, it is by some code not compatible with libp2p-quic. + inner + .connection + .close(Instant::now(), From::from(0u32), Default::default()); + return Poll::Ready(Err(QuicMuxerError::ProtocolViolation)); + } + } + } + + // TODO quinn doesn't support `StreamMuxerEvent::AddressChange`. + + if inner.accept_incoming { + if let Some(id) = inner.connection.streams().accept(Dir::Bi) { + inner.substreams.insert(id, Default::default()); + tracing::trace!("opened incoming substream {}", id); + return Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(id))); + } + } + + if inner.substreams.is_empty() { + if let Some(waker) = inner.close_waker.take() { + waker.wake(); + } + } + + inner.waker = Some(cx.waker().clone()); + Poll::Pending + } + + fn open_outbound(&self) -> Self::OutboundSubstream {} + + fn poll_outbound( + &self, + cx: &mut Context, + _: &mut Self::OutboundSubstream, + ) -> Poll> { + let mut inner = self.inner.lock(); + if let Some(id) = inner.connection.streams().open(Dir::Bi) { + tracing::trace!("opened outgoing substream {}", id); + inner.substreams.insert(id, Default::default()); + if let Some(waker) = inner.pending_substreams.pop_front() { + waker.wake(); + } + Poll::Ready(Ok(id)) + } else { + inner.pending_substreams.push_back(cx.waker().clone()); + Poll::Pending + } + } + + fn destroy_outbound(&self, _: Self::OutboundSubstream) {} + + fn read_substream( + &self, + cx: &mut Context, + id: &mut Self::Substream, + mut buf: &mut [u8], + ) -> Poll> { + let mut inner = self.inner.lock(); + let mut stream = inner.connection.recv_stream(*id); + let mut chunks = match stream.read(true) { + Ok(chunks) => chunks, + Err(ReadableError::UnknownStream) => { + return Poll::Ready(Err(QuicMuxerError::UnknownStream { id: *id })) + } + Err(ReadableError::IllegalOrderedRead) => { + panic!("Illegal ordered read can only happen if `stream.read(false)` is used."); + } + }; + let mut bytes = 0; + let mut pending = false; + loop { + if buf.is_empty() { + break; + } + match chunks.next(buf.len()) { + Ok(Some(chunk)) => { + buf.write_all(&chunk.bytes).expect("enough buffer space"); + bytes += chunk.bytes.len(); + } + Ok(None) => break, + Err(ReadError::Reset(error_code)) => { + tracing::debug!("substream {} was reset with error code {}", id, error_code); + bytes = 0; + break; + } + Err(ReadError::Blocked) => { + pending = true; + break; + } + } + } + if chunks.finalize().should_transmit() { + if let Some(waker) = inner.waker.take() { + waker.wake(); + } + } + let substream = inner.substreams.get_mut(id).unwrap(); + if pending && bytes == 0 { + substream.read_waker = Some(cx.waker().clone()); + Poll::Pending + } else { + Poll::Ready(Ok(bytes)) + } + } + + fn write_substream( + &self, + cx: &mut Context, + id: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { + let mut inner = self.inner.lock(); + match inner.connection.send_stream(*id).write(buf) { + Ok(bytes) => Poll::Ready(Ok(bytes)), + Err(WriteError::Blocked) => { + let mut substream = inner.substreams.get_mut(id).unwrap(); + substream.write_waker = Some(cx.waker().clone()); + Poll::Pending + } + Err(WriteError::Stopped(_)) => Poll::Ready(Ok(0)), + Err(WriteError::UnknownStream) => { + Poll::Ready(Err(QuicMuxerError::UnknownStream { id: *id })) + } + } + } + + fn shutdown_substream( + &self, + _: &mut Context, + id: &mut Self::Substream, + ) -> Poll> { + tracing::trace!("closing substream {}", id); + // closes the write end of the substream without waiting for the remote to receive the + // event. use flush substream to wait for the remote to receive the event. + let mut inner = self.inner.lock(); + match inner.connection.send_stream(*id).finish() { + Ok(()) => Poll::Ready(Ok(())), + Err(FinishError::Stopped(_)) => Poll::Ready(Ok(())), + Err(FinishError::UnknownStream) => { + Poll::Ready(Err(QuicMuxerError::UnknownStream { id: *id })) + } + } + } + + fn destroy_substream(&self, id: Self::Substream) { + tracing::trace!("destroying substream {}", id); + let mut inner = self.inner.lock(); + inner.substreams.remove(&id); + let mut stream = inner.connection.recv_stream(id); + let should_transmit = if let Ok(mut chunks) = stream.read(true) { + while let Ok(Some(_)) = chunks.next(usize::MAX) {} + chunks.finalize().should_transmit() + } else { + false + }; + if should_transmit { + if let Some(waker) = inner.waker.take() { + waker.wake(); + } + } + } + + fn flush_substream( + &self, + _cx: &mut Context, + _id: &mut Self::Substream, + ) -> Poll> { + // quinn doesn't support flushing, calling close will flush all substreams. + Poll::Ready(Ok(())) + } + + fn flush_all(&self, _cx: &mut Context) -> Poll> { + // quinn doesn't support flushing, calling close will flush all substreams. + Poll::Ready(Ok(())) + } + + fn close(&self, cx: &mut Context) -> Poll> { + tracing::trace!("closing muxer"); + let mut inner = self.inner.lock(); + if inner.substreams.is_empty() { + return Poll::Ready(Ok(())); + } + inner.close_waker = Some(cx.waker().clone()); + let inner = &mut *inner; + for id in inner.substreams.keys() { + let _ = inner.connection.send_stream(*id).finish(); + } + Poll::Pending + } +} + +#[derive(Debug, Error)] +pub enum QuicMuxerError { + #[error("connection was lost because of {reason}")] + ConnectionLost { reason: ConnectionError }, + #[error("unsupported quic feature used")] + ProtocolViolation, + #[error("stream {id} stopped with error {error_code}")] + StreamStopped { id: StreamId, error_code: VarInt }, + #[error("unknown stream {id}")] + UnknownStream { id: StreamId }, +} + +impl From for std::io::Error { + fn from(err: QuicMuxerError) -> Self { + std::io::Error::new(std::io::ErrorKind::Other, err) + } +} diff --git a/transports/quic/src/tls/certificate.rs b/transports/quic/src/tls/certificate.rs new file mode 100644 index 00000000000..59b6c2a4e77 --- /dev/null +++ b/transports/quic/src/tls/certificate.rs @@ -0,0 +1,76 @@ +// Copyright 2020 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Certificate handling for libp2p +//! +//! This module handles generation, signing, and verification of certificates. + +use super::LIBP2P_SIGNING_PREFIX_LENGTH; +use libp2p_core::identity::Keypair; + +const LIBP2P_OID: &[u64] = &[1, 3, 6, 1, 4, 1, 53594, 1, 1]; // Based on libp2p TLS 1.3 specs +const LIBP2P_SIGNATURE_ALGORITHM_PUBLIC_KEY_LENGTH: usize = 91; +static LIBP2P_SIGNATURE_ALGORITHM: &rcgen::SignatureAlgorithm = &rcgen::PKCS_ECDSA_P256_SHA256; + +/// Generates a self-signed TLS certificate that includes a libp2p-specific +/// certificate extension containing the public key of the given keypair. +pub(crate) fn make_cert(keypair: &Keypair) -> Result { + // Keypair used to sign the certificate. + let certif_keypair = rcgen::KeyPair::generate(LIBP2P_SIGNATURE_ALGORITHM)?; + + // The libp2p-specific extension to the certificate contains a signature of the public key + // of the certificate using the libp2p private key. + let libp2p_ext_signature = { + let certif_pubkey = certif_keypair.public_key_der(); + assert_eq!( + certif_pubkey.len(), + LIBP2P_SIGNATURE_ALGORITHM_PUBLIC_KEY_LENGTH, + ); + + let mut buf = + [0u8; LIBP2P_SIGNING_PREFIX_LENGTH + LIBP2P_SIGNATURE_ALGORITHM_PUBLIC_KEY_LENGTH]; + buf[..LIBP2P_SIGNING_PREFIX_LENGTH].copy_from_slice(&super::LIBP2P_SIGNING_PREFIX[..]); + buf[LIBP2P_SIGNING_PREFIX_LENGTH..].copy_from_slice(&certif_pubkey); + keypair.sign(&buf)? + }; + + // Generate the libp2p-specific extension. + let libp2p_extension: rcgen::CustomExtension = { + let extension_content = { + let serialized_pubkey = keypair.public().to_protobuf_encoding(); + yasna::encode_der(&(serialized_pubkey, libp2p_ext_signature)) + }; + + let mut ext = rcgen::CustomExtension::from_oid_content(LIBP2P_OID, extension_content); + ext.set_criticality(true); + ext + }; + + let certificate = { + let mut params = rcgen::CertificateParams::new(vec![]); + params.distinguished_name = rcgen::DistinguishedName::new(); + params.custom_extensions.push(libp2p_extension); + params.alg = LIBP2P_SIGNATURE_ALGORITHM; + params.key_pair = Some(certif_keypair); + rcgen::Certificate::from_params(params)? + }; + + Ok(certificate) +} diff --git a/transports/quic/src/tls/mod.rs b/transports/quic/src/tls/mod.rs new file mode 100644 index 00000000000..1cf8f3571b6 --- /dev/null +++ b/transports/quic/src/tls/mod.rs @@ -0,0 +1,86 @@ +// Copyright 2020 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! TLS configuration for `libp2p-quic`. + +mod certificate; +mod verifier; + +use libp2p_core::PeerId; +use std::sync::Arc; +use thiserror::Error; + +pub use verifier::extract_peerid_or_panic; + +const LIBP2P_SIGNING_PREFIX: [u8; 21] = *b"libp2p-tls-handshake:"; +const LIBP2P_SIGNING_PREFIX_LENGTH: usize = LIBP2P_SIGNING_PREFIX.len(); +const LIBP2P_OID_BYTES: &[u8] = &[43, 6, 1, 4, 1, 131, 162, 90, 1, 1]; // Based on libp2p TLS 1.3 specs + +/// Error creating a configuration +// TODO: remove this; what is the user supposed to do with this error? +#[derive(Debug, Error)] +pub enum ConfigError { + /// TLS private key or certificate rejected + #[error("TLS private or certificate key rejected: {0}")] + TLSError(#[from] rustls::TLSError), + /// Signing failed + #[error("Signing failed: {0}")] + SigningError(#[from] libp2p_core::identity::error::SigningError), + /// Certificate generation error + #[error("Certificate generation error: {0}")] + RcgenError(#[from] rcgen::RcgenError), +} + +pub fn make_client_config( + keypair: &libp2p_core::identity::Keypair, + remote_peer_id: PeerId, +) -> Result { + let cert = certificate::make_cert(keypair)?; + let private_key = cert.serialize_private_key_der(); + let cert = rustls::Certificate(cert.serialize_der()?); + let key = rustls::PrivateKey(private_key); + let verifier = verifier::Libp2pServerCertificateVerifier(remote_peer_id); + + let mut crypto = rustls::ClientConfig::new(); + crypto.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + crypto.alpn_protocols = vec![b"libp2p".to_vec()]; + crypto.enable_early_data = false; + crypto.set_single_client_cert(vec![cert], key)?; + crypto + .dangerous() + .set_certificate_verifier(Arc::new(verifier)); + Ok(crypto) +} + +pub fn make_server_config( + keypair: &libp2p_core::identity::Keypair, +) -> Result { + let cert = certificate::make_cert(keypair)?; + let private_key = cert.serialize_private_key_der(); + let cert = rustls::Certificate(cert.serialize_der()?); + let key = rustls::PrivateKey(private_key); + let verifier = verifier::Libp2pClientCertificateVerifier; + + let mut crypto = rustls::ServerConfig::new(Arc::new(verifier)); + crypto.versions = vec![rustls::ProtocolVersion::TLSv1_3]; + crypto.alpn_protocols = vec![b"libp2p".to_vec()]; + crypto.set_single_cert(vec![cert], key)?; + Ok(crypto) +} diff --git a/transports/quic/src/tls/verifier.rs b/transports/quic/src/tls/verifier.rs new file mode 100644 index 00000000000..9dcf2ed2298 --- /dev/null +++ b/transports/quic/src/tls/verifier.rs @@ -0,0 +1,244 @@ +// Copyright 2020 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use libp2p_core::identity::PublicKey; +use libp2p_core::PeerId; +use ring::io::der; +use rustls::{ + internal::msgs::handshake::DigitallySignedStruct, Certificate, ClientCertVerified, + HandshakeSignatureValid, ServerCertVerified, TLSError, +}; +use untrusted::{Input, Reader}; +use webpki::Error; + +/// Implementation of the `rustls` certificate verification traits for libp2p. +/// +/// Only TLS 1.3 is supported. TLS 1.2 should be disabled in the configuration of `rustls`. +pub(crate) struct Libp2pServerCertificateVerifier(pub(crate) PeerId); + +/// libp2p requires the following of X.509 server certificate chains: +/// +/// - Exactly one certificate must be presented. +/// - The certificate must be self-signed. +/// - The certificate must have a valid libp2p extension that includes a +/// signature of its public key. +impl rustls::ServerCertVerifier for Libp2pServerCertificateVerifier { + fn verify_server_cert( + &self, + _roots: &rustls::RootCertStore, + presented_certs: &[rustls::Certificate], + _dns_name: webpki::DNSNameRef<'_>, + _ocsp_response: &[u8], + ) -> Result { + let peer_id = verify_presented_certs(presented_certs)?; + if peer_id != self.0 { + return Err(TLSError::PeerIncompatibleError( + "Unexpected peer id".to_string(), + )); + } + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + Err(TLSError::PeerIncompatibleError( + "Only TLS 1.3 certificates are supported".to_string(), + )) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls13_signature(message, cert, dss) + } +} + +/// Implementation of the `rustls` certificate verification traits for libp2p. +/// +/// Only TLS 1.3 is supported. TLS 1.2 should be disabled in the configuration of `rustls`. +pub(crate) struct Libp2pClientCertificateVerifier; + +/// libp2p requires the following of X.509 client certificate chains: +/// +/// - Exactly one certificate must be presented. In particular, client +/// authentication is mandatory in libp2p. +/// - The certificate must be self-signed. +/// - The certificate must have a valid libp2p extension that includes a +/// signature of its public key. +impl rustls::ClientCertVerifier for Libp2pClientCertificateVerifier { + fn offer_client_auth(&self) -> bool { + true + } + + fn client_auth_root_subjects( + &self, + _dns_name: Option<&webpki::DNSName>, + ) -> Option { + Some(vec![]) + } + + fn verify_client_cert( + &self, + presented_certs: &[Certificate], + _dns_name: Option<&webpki::DNSName>, + ) -> Result { + verify_presented_certs(presented_certs).map(|_| ClientCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + Err(TLSError::PeerIncompatibleError( + "Only TLS 1.3 certificates are supported".to_string(), + )) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &DigitallySignedStruct, + ) -> Result { + barebones_x509::parse_certificate(cert.as_ref()) + .map_err(rustls::TLSError::WebPKIError)? + .check_tls13_signature(dss.scheme, message, dss.sig.0.as_ref()) + .map_err(rustls::TLSError::WebPKIError) + .map(|()| rustls::HandshakeSignatureValid::assertion()) + } +} + +fn verify_tls13_signature( + message: &[u8], + cert: &Certificate, + dss: &DigitallySignedStruct, +) -> Result { + barebones_x509::parse_certificate(cert.as_ref()) + .map_err(rustls::TLSError::WebPKIError)? + .check_tls13_signature(dss.scheme, message, dss.sig.0.as_ref()) + .map_err(rustls::TLSError::WebPKIError) + .map(|()| rustls::HandshakeSignatureValid::assertion()) +} + +fn verify_libp2p_signature( + libp2p_extension: &Libp2pExtension<'_>, + x509_pkey_bytes: &[u8], +) -> Result<(), Error> { + let mut v = Vec::with_capacity(super::LIBP2P_SIGNING_PREFIX_LENGTH + x509_pkey_bytes.len()); + v.extend_from_slice(&super::LIBP2P_SIGNING_PREFIX[..]); + v.extend_from_slice(x509_pkey_bytes); + if libp2p_extension + .peer_key + .verify(&v, libp2p_extension.signature) + { + Ok(()) + } else { + Err(Error::UnknownIssuer) + } +} + +fn parse_certificate( + certificate: &[u8], +) -> Result<(barebones_x509::X509Certificate<'_>, Libp2pExtension<'_>), Error> { + let parsed = barebones_x509::parse_certificate(certificate)?; + let mut libp2p_extension = None; + + parsed + .extensions() + .iterate(&mut |oid, critical, extension| { + match oid { + super::LIBP2P_OID_BYTES if libp2p_extension.is_some() => return Err(Error::BadDER), + super::LIBP2P_OID_BYTES => { + libp2p_extension = Some(parse_libp2p_extension(extension)?) + } + _ if critical => return Err(Error::UnsupportedCriticalExtension), + _ => {} + }; + Ok(()) + })?; + let libp2p_extension = libp2p_extension.ok_or(Error::UnknownIssuer)?; + Ok((parsed, libp2p_extension)) +} + +fn verify_presented_certs(presented_certs: &[Certificate]) -> Result { + if presented_certs.len() != 1 { + return Err(TLSError::NoCertificatesPresented); + } + let (certificate, extension) = + parse_certificate(presented_certs[0].as_ref()).map_err(TLSError::WebPKIError)?; + certificate.valid().map_err(TLSError::WebPKIError)?; + certificate + .check_self_issued() + .map_err(TLSError::WebPKIError)?; + verify_libp2p_signature(&extension, certificate.subject_public_key_info().spki()) + .map_err(TLSError::WebPKIError)?; + Ok(PeerId::from_public_key(&extension.peer_key)) +} + +struct Libp2pExtension<'a> { + peer_key: PublicKey, + signature: &'a [u8], +} + +fn parse_libp2p_extension(extension: Input<'_>) -> Result, Error> { + fn read_bit_string<'a>(input: &mut Reader<'a>, e: Error) -> Result, Error> { + // The specification states that this is a BIT STRING, but the Go implementation + // uses an OCTET STRING. OCTET STRING is superior in this context, so use it. + der::expect_tag_and_get_value(input, der::Tag::OctetString).map_err(|_| e) + } + + let e = Error::ExtensionValueInvalid; + Input::read_all(&extension, e, |input| { + der::nested(input, der::Tag::Sequence, e, |input| { + let public_key = read_bit_string(input, e)?.as_slice_less_safe(); + let signature = read_bit_string(input, e)?.as_slice_less_safe(); + // We deliberately discard the error information because this is + // either a broken peer or an attack. + let peer_key = PublicKey::from_protobuf_encoding(public_key).map_err(|_| e)?; + Ok(Libp2pExtension { + peer_key, + signature, + }) + }) + }) +} + +/// Extracts the `PeerId` from a certificate’s libp2p extension. It is erroneous +/// to call this unless the certificate is known to be a well-formed X.509 +/// certificate with a valid libp2p extension. The certificate verifier in this +/// module check this. +/// +/// # Panics +/// +/// Panics if called on an invalid certificate. +pub fn extract_peerid_or_panic(certificate: &[u8]) -> PeerId { + let r = parse_certificate(certificate) + .expect("we already checked that the certificate was valid during the handshake; qed"); + PeerId::from_public_key(&r.1.peer_key) +} diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs new file mode 100644 index 00000000000..0cd77026634 --- /dev/null +++ b/transports/quic/src/transport.rs @@ -0,0 +1,385 @@ +use crate::crypto::Crypto; +use crate::endpoint::{EndpointConfig, TransportChannel}; +use crate::muxer::QuicMuxer; +use crate::{QuicConfig, QuicError}; +use futures::channel::oneshot; +use futures::prelude::*; +use if_watch::{IfEvent, IfWatcher}; +use libp2p_core::multiaddr::{Multiaddr, Protocol}; +use libp2p_core::muxing::{StreamMuxer, StreamMuxerBox}; +use libp2p_core::transport::{Boxed, ListenerEvent, Transport, TransportError}; +use libp2p_core::PeerId; +use parking_lot::Mutex; +use quinn_proto::crypto::Session; +use std::net::{IpAddr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use udp_socket::SocketType; + +#[derive(Clone)] +pub struct QuicTransport { + inner: Arc>>, +} + +impl QuicTransport +where + ::ClientConfig: Send + Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + /// Creates a new quic transport. + pub async fn new( + config: QuicConfig, + addr: Multiaddr, + ) -> Result> { + let socket_addr = multiaddr_to_socketaddr::(&addr) + .map_err(|_| TransportError::MultiaddrNotSupported(addr.clone()))? + .0; + let addresses = if socket_addr.ip().is_unspecified() { + let watcher = IfWatcher::new() + .await + .map_err(|err| TransportError::Other(err.into()))?; + Addresses::Unspecified(watcher) + } else { + Addresses::Ip(Some(socket_addr.ip())) + }; + let endpoint = EndpointConfig::new(config, socket_addr).map_err(TransportError::Other)?; + Ok(Self { + inner: Arc::new(Mutex::new(QuicTransportInner { + channel: endpoint.spawn(), + addresses, + })), + }) + } + + /// Creates a boxed libp2p transport. + pub fn boxed(self) -> Boxed<(PeerId, StreamMuxerBox)> { + Transport::map(self, |(peer_id, muxer), _| { + (peer_id, StreamMuxerBox::new(muxer)) + }) + .boxed() + } +} + +impl std::fmt::Debug for QuicTransport { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("QuicTransport").finish() + } +} + +struct QuicTransportInner { + channel: TransportChannel, + addresses: Addresses, +} + +enum Addresses { + Unspecified(IfWatcher), + Ip(Option), +} + +impl Transport for QuicTransport +where + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + type Output = (PeerId, QuicMuxer); + type Error = QuicError; + type Listener = Self; + type ListenerUpgrade = QuicUpgrade; + type Dial = QuicDial; + + fn listen_on(self, addr: Multiaddr) -> Result> { + multiaddr_to_socketaddr::(&addr) + .map_err(|_| TransportError::MultiaddrNotSupported(addr))?; + Ok(self) + } + + fn dial(self, addr: Multiaddr) -> Result> { + let (socket_addr, public_key) = + if let Ok((socket_addr, Some(public_key))) = multiaddr_to_socketaddr::(&addr) { + (socket_addr, public_key) + } else { + tracing::debug!("invalid multiaddr"); + return Err(TransportError::MultiaddrNotSupported(addr.clone())); + }; + if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() { + tracing::debug!("invalid multiaddr"); + return Err(TransportError::MultiaddrNotSupported(addr)); + } + tracing::debug!("dialing {}", socket_addr); + let rx = self.inner.lock().channel.dial(socket_addr, public_key); + Ok(QuicDial::Dialing(rx)) + } + + fn address_translation(&self, _listen: &Multiaddr, observed: &Multiaddr) -> Option { + Some(observed.clone()) + } +} + +impl Stream for QuicTransport { + type Item = Result, QuicError>, QuicError>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let mut inner = self.inner.lock(); + match &mut inner.addresses { + Addresses::Ip(ip) => { + if let Some(ip) = ip.take() { + let addr = socketaddr_to_multiaddr(&SocketAddr::new(ip, inner.channel.port())); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(addr)))); + } + } + Addresses::Unspecified(watcher) => match Pin::new(watcher).poll(cx) { + Poll::Ready(Ok(IfEvent::Up(net))) => { + if inner.channel.ty() == SocketType::Ipv4 && net.addr().is_ipv4() + || inner.channel.ty() != SocketType::Ipv4 && net.addr().is_ipv6() + { + let addr = socketaddr_to_multiaddr(&SocketAddr::new( + net.addr(), + inner.channel.port(), + )); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(addr)))); + } + } + Poll::Ready(Ok(IfEvent::Down(net))) => { + if inner.channel.ty() == SocketType::Ipv4 && net.addr().is_ipv4() + || inner.channel.ty() != SocketType::Ipv4 && net.addr().is_ipv6() + { + let addr = socketaddr_to_multiaddr(&SocketAddr::new( + net.addr(), + inner.channel.port(), + )); + return Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(addr)))); + } + } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err.into()))), + Poll::Pending => {} + }, + } + match inner.channel.poll_incoming(cx) { + Poll::Ready(Some(Ok(muxer))) => Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + local_addr: muxer.local_addr(), + remote_addr: muxer.remote_addr(), + upgrade: QuicUpgrade::new(muxer), + }))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[allow(clippy::large_enum_variant)] +pub enum QuicDial { + Dialing(oneshot::Receiver, QuicError>>), + Upgrade(QuicUpgrade), +} + +impl Future for QuicDial +where + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + type Output = Result<(PeerId, QuicMuxer), QuicError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + loop { + match &mut *self { + Self::Dialing(rx) => match Pin::new(rx).poll(cx) { + Poll::Ready(Ok(Ok(muxer))) => { + *self = Self::Upgrade(QuicUpgrade::new(muxer)); + } + Poll::Ready(Ok(Err(err))) => return Poll::Ready(Err(err)), + Poll::Ready(Err(_)) => panic!("endpoint crashed"), + Poll::Pending => return Poll::Pending, + }, + Self::Upgrade(upgrade) => return Pin::new(upgrade).poll(cx), + } + } + } +} + +pub struct QuicUpgrade { + muxer: Option>, +} + +impl QuicUpgrade { + fn new(muxer: QuicMuxer) -> Self { + Self { muxer: Some(muxer) } + } +} + +impl Future for QuicUpgrade +where + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + type Output = Result<(PeerId, QuicMuxer), QuicError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner = Pin::into_inner(self); + let muxer = inner.muxer.as_mut().expect("future polled after ready"); + match muxer.poll_event(cx) { + Poll::Pending => { + if let Some(peer_id) = muxer.peer_id() { + muxer.set_accept_incoming(true); + Poll::Ready(Ok(( + peer_id, + inner.muxer.take().expect("future polled after ready"), + ))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), + Poll::Ready(Ok(_)) => { + unreachable!("muxer.incoming is set to false so no events can be produced"); + } + } + } +} + +/// Tries to turn a QUIC multiaddress into a UDP [`SocketAddr`]. Returns an error if the format +/// of the multiaddr is wrong. +fn multiaddr_to_socketaddr( + addr: &Multiaddr, +) -> Result<(SocketAddr, Option), ()> { + let mut iter = addr.iter().peekable(); + let proto1 = iter.next().ok_or(())?; + let proto2 = iter.next().ok_or(())?; + let proto3 = iter.next().ok_or(())?; + + let peer_id = if let Some(Protocol::P2p(peer_id)) = iter.peek() { + if peer_id.code() != multihash::Code::Identity.into() { + return Err(()); + } + let public_key = + libp2p_core::PublicKey::from_protobuf_encoding(peer_id.digest()).map_err(|_| ())?; + let public_key = C::extract_public_key(public_key).ok_or(())?; + iter.next(); + Some(public_key) + } else { + None + }; + + if iter.next().is_some() { + return Err(()); + } + + match (proto1, proto2, proto3) { + (Protocol::Ip4(ip), Protocol::Udp(port), Protocol::Quic) => { + Ok((SocketAddr::new(ip.into(), port), peer_id)) + } + (Protocol::Ip6(ip), Protocol::Udp(port), Protocol::Quic) => { + Ok((SocketAddr::new(ip.into(), port), peer_id)) + } + _ => Err(()), + } +} + +/// Turns an IP address and port into the corresponding QUIC multiaddr. +pub(crate) fn socketaddr_to_multiaddr(socket_addr: &SocketAddr) -> Multiaddr { + Multiaddr::empty() + .with(socket_addr.ip().into()) + .with(Protocol::Udp(socket_addr.port())) + .with(Protocol::Quic) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn multiaddr_to_udp_conversion() { + use std::net::{Ipv4Addr, Ipv6Addr}; + + assert!(multiaddr_to_socketaddr::( + &"/ip4/127.0.0.1/udp/1234".parse::().unwrap() + ) + .is_err()); + + assert_eq!( + multiaddr_to_socketaddr::( + &"/ip4/127.0.0.1/udp/12345/quic" + .parse::() + .unwrap() + ), + Ok(( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345,), + None + )) + ); + assert_eq!( + multiaddr_to_socketaddr::( + &"/ip4/255.255.255.255/udp/8080/quic" + .parse::() + .unwrap() + ), + Ok(( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 8080,), + None + )) + ); + assert_eq!( + multiaddr_to_socketaddr::(&"/ip6/::1/udp/12345/quic".parse::().unwrap()), + Ok(( + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 12345,), + None + )) + ); + assert_eq!( + multiaddr_to_socketaddr::( + &"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/udp/8080/quic" + .parse::() + .unwrap() + ), + Ok(( + SocketAddr::new( + IpAddr::V6(Ipv6Addr::new( + 65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535, + )), + 8080, + ), + None + )) + ); + } + + #[cfg(feature = "noise")] + #[test] + fn multiaddr_to_udp_noise() { + multiaddr_to_udp_conversion::(); + } + #[cfg(feature = "tls")] + #[test] + fn multiaddr_to_udp_tls() { + multiaddr_to_udp_conversion::(); + } + + fn multiaddr_to_pk_conversion(keypair: C::Keypair) { + use crate::crypto::ToLibp2p; + use std::net::Ipv4Addr; + + let peer_id = keypair.to_public().to_peer_id(); + let addr = String::from("/ip4/127.0.0.1/udp/12345/quic/p2p/") + &peer_id.to_base58(); + assert_eq!( + multiaddr_to_socketaddr::(&addr.parse::().unwrap()), + Ok(( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345,), + C::extract_public_key(keypair.to_public()) + )) + ); + } + + #[cfg(feature = "tls")] + #[test] + fn multiaddr_to_pk_tls() { + let keypair = libp2p_core::identity::Keypair::generate_ed25519(); + multiaddr_to_pk_conversion::(keypair); + } + #[cfg(feature = "noise")] + #[test] + fn multiaddr_to_pk_noise() { + let keypair = ed25519_dalek::Keypair::generate(&mut rand_core::OsRng {}); + multiaddr_to_pk_conversion::(keypair); + } +} diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs new file mode 100644 index 00000000000..3ee821e956a --- /dev/null +++ b/transports/quic/tests/smoke.rs @@ -0,0 +1,392 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::future::FutureExt; +use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use futures::stream::StreamExt; +use libp2p::core::upgrade; +use libp2p::multiaddr::Protocol; +use libp2p::request_response::{ + ProtocolName, ProtocolSupport, RequestResponse, RequestResponseCodec, RequestResponseConfig, + RequestResponseEvent, RequestResponseMessage, +}; +use libp2p::swarm::{Swarm, SwarmEvent}; +use libp2p_quic::{Crypto, QuicConfig, ToLibp2p}; +use quinn_proto::crypto::Session; +use rand::RngCore; +use std::{io, iter}; + +#[cfg(feature = "noise")] +fn generate_noise_keypair() -> ed25519_dalek::Keypair { + ed25519_dalek::Keypair::generate(&mut rand_core::OsRng {}) +} +#[cfg(feature = "tls")] +fn generate_tls_keypair() -> libp2p::identity::Keypair { + libp2p::identity::Keypair::generate_ed25519() +} + +#[cfg(feature = "noise")] +#[async_std::test] +async fn smoke_noise() -> Result<()> { + smoke::().await +} + +#[cfg(feature = "tls")] +#[async_std::test] +async fn smoke_tls() -> Result<()> { + smoke::().await +} + +trait GenerateKeypair: Crypto { + fn generate_keypair() -> Self::Keypair; +} + +#[cfg(feature = "noise")] +impl GenerateKeypair for libp2p_quic::NoiseCrypto { + fn generate_keypair() -> Self::Keypair { + generate_noise_keypair() + } +} + +#[cfg(feature = "tls")] +impl GenerateKeypair for libp2p_quic::TlsCrypto { + fn generate_keypair() -> Self::Keypair { + generate_tls_keypair() + } +} + +async fn create_swarm( + keylog: bool, +) -> Result>> +where + ::ClientConfig: Send + Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + let keypair = C::generate_keypair(); + let peer_id = keypair.to_peer_id(); + let mut transport = QuicConfig::::new(keypair); + if keylog { + transport.enable_keylogger(); + } + let transport = transport + .listen_on("/ip4/127.0.0.1/udp/0/quic".parse()?) + .await? + .boxed(); + + let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); + let cfg = RequestResponseConfig::default(); + let behaviour = RequestResponse::new(PingCodec(), protocols, cfg); + tracing::info!("{}", peer_id); + Ok(Swarm::new(transport, behaviour, peer_id)) +} + +async fn smoke() -> Result<()> +where + ::ClientConfig: Send + Unpin, + ::HeaderKey: Unpin, + ::PacketKey: Unpin, +{ + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .ok(); + log_panics::init(); + let mut rng = rand::thread_rng(); + + let mut a = create_swarm::(true).await?; + let mut b = create_swarm::(false).await?; + + Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0/quic".parse()?)?; + + let mut addr = match a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + addr.push(Protocol::P2p((*a.local_peer_id()).into())); + + let mut data = vec![0; 4096 * 10]; + rng.fill_bytes(&mut data); + + b.behaviour_mut() + .add_address(&Swarm::local_peer_id(&a), addr); + b.behaviour_mut() + .send_request(&Swarm::local_peer_id(&a), Ping(data.clone())); + + match b.next().await { + Some(SwarmEvent::Dialing(_)) => {} + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::IncomingConnection { .. }) => {} + e => panic!("{:?}", e), + }; + + match b.next().await { + Some(SwarmEvent::ConnectionEstablished { .. }) => {} + e => panic!("{:?}", e), + }; + + match a.next().await { + Some(SwarmEvent::ConnectionEstablished { .. }) => {} + e => panic!("{:?}", e), + }; + + assert!(b.next().now_or_never().is_none()); + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Request { + request: Ping(ping), + channel, + .. + }, + .. + })) => { + a.behaviour_mut() + .send_response(channel, Pong(ping)) + .unwrap(); + } + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {} + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Response { + response: Pong(pong), + .. + }, + .. + })) => assert_eq!(data, pong), + e => panic!("{:?}", e), + } + + a.behaviour_mut().send_request( + &Swarm::local_peer_id(&b), + Ping(b"another substream".to_vec()), + ); + + assert!(a.next().now_or_never().is_none()); + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Request { + request: Ping(data), + channel, + .. + }, + .. + })) => { + b.behaviour_mut() + .send_response(channel, Pong(data)) + .unwrap(); + } + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { .. })) => {} + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::Message { + message: + RequestResponseMessage::Response { + response: Pong(data), + .. + }, + .. + })) => assert_eq!(data, b"another substream".to_vec()), + e => panic!("{:?}", e), + } + + Ok(()) +} + +#[derive(Debug, Clone)] +struct PingProtocol(); + +#[derive(Clone)] +struct PingCodec(); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Ping(Vec); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Pong(Vec); + +impl ProtocolName for PingProtocol { + fn protocol_name(&self) -> &[u8] { + "/ping/1".as_bytes() + } +} + +#[async_trait] +impl RequestResponseCodec for PingCodec { + type Protocol = PingProtocol; + type Request = Ping; + type Response = Pong; + + async fn read_request(&mut self, _: &PingProtocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + upgrade::read_length_prefixed(io, 4096 * 10) + .map(|res| match res { + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()), + Ok(vec) => Ok(Ping(vec)), + }) + .await + } + + async fn read_response(&mut self, _: &PingProtocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send, + { + upgrade::read_length_prefixed(io, 4096 * 10) + .map(|res| match res { + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + Ok(vec) if vec.is_empty() => Err(io::ErrorKind::UnexpectedEof.into()), + Ok(vec) => Ok(Pong(vec)), + }) + .await + } + + async fn write_request( + &mut self, + _: &PingProtocol, + io: &mut T, + Ping(data): Ping, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + upgrade::write_length_prefixed(io, data).await?; + io.close().await?; + Ok(()) + } + + async fn write_response( + &mut self, + _: &PingProtocol, + io: &mut T, + Pong(data): Pong, + ) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + { + upgrade::write_length_prefixed(io, data).await?; + io.close().await?; + Ok(()) + } +} + +#[cfg(feature = "noise")] +#[async_std::test] +async fn dial_failure_noise() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .ok(); + log_panics::init(); + + let mut a = create_swarm::(true).await?; + let mut b = create_swarm::(false).await?; + + Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0/quic".parse()?)?; + + let keypair = libp2p_quic::NoiseCrypto::generate_keypair(); + let fake_peer_id = keypair.to_peer_id(); + + let mut addr = match a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + addr.push(Protocol::P2p(fake_peer_id.into())); + + b.behaviour_mut().add_address(&fake_peer_id, addr); + b.behaviour_mut() + .send_request(&fake_peer_id, Ping(b"hello world".to_vec())); + + match b.next().await { + Some(SwarmEvent::Dialing(_)) => {} + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::ConnectionEstablished { .. }) => {} + e => panic!("{:?}", e), + }; + + match b.next().await { + Some(SwarmEvent::ConnectionClosed { .. }) => {} + e => panic!("{:?}", e), + }; + + assert!(a.next().now_or_never().is_none()); + + Ok(()) +} + +#[cfg(feature = "tls")] +#[async_std::test] +async fn dial_failure_tls() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .ok(); + log_panics::init(); + + let mut a = create_swarm::(true).await?; + let mut b = create_swarm::(false).await?; + + Swarm::listen_on(&mut a, "/ip4/127.0.0.1/udp/0/quic".parse()?)?; + + let keypair = libp2p_quic::TlsCrypto::generate_keypair(); + let fake_peer_id = keypair.to_peer_id(); + + let mut addr = match a.next().await { + Some(SwarmEvent::NewListenAddr { address, .. }) => address, + e => panic!("{:?}", e), + }; + addr.push(Protocol::P2p(fake_peer_id.into())); + + b.behaviour_mut().add_address(&fake_peer_id, addr); + b.behaviour_mut() + .send_request(&fake_peer_id, Ping(b"hello world".to_vec())); + + match b.next().await { + Some(SwarmEvent::Dialing(_)) => {} + e => panic!("{:?}", e), + } + + match a.next().await { + Some(SwarmEvent::IncomingConnection { .. }) => {} + e => panic!("{:?}", e), + } + + match b.next().await { + Some(SwarmEvent::UnreachableAddr { .. }) => {} + e => panic!("{:?}", e), + }; + + match b.next().await { + Some(SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure { .. })) => {} + e => panic!("{:?}", e), + }; + + assert!(a.next().await.is_some()); // ConnectionClosed + assert!(a.next().now_or_never().is_none()); + + Ok(()) +}