Skip to content

Commit

Permalink
feat(transport): Fix TLS accept w/ peer certs (#535)
Browse files Browse the repository at this point in the history
* feat(transport): Fix TLS accept w/ peer certs

* fix unused var

* fix feature flag imports

* spawn accept task
  • Loading branch information
LucioFranco authored Jan 15, 2021
1 parent 4974604 commit 41c51f1
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 138 deletions.
8 changes: 5 additions & 3 deletions examples/src/tls_client_auth/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ pub struct EchoServer;
#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
if let Some(certs) = request.peer_certs() {
println!("Got {} peer certs!", certs.len());
}
let certs = request
.peer_certs()
.expect("Client did not send its certs!");

println!("Got {} peer certs!", certs.len());

let message = request.into_inner().message;
Ok(Response::new(EchoResponse { message }))
Expand Down
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ transport = [
"tokio",
"tower",
"tracing-futures",
"tokio/macros"
]
tls = ["transport", "tokio-rustls"]
tls-roots = ["tls", "rustls-native-certs"]
Expand Down
30 changes: 12 additions & 18 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#[cfg(feature = "tls")]
use super::TlsStream;
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;
#[cfg(feature = "tls")]
use tokio_rustls::rustls::Session;
use tokio_rustls::{rustls::Session, server::TlsStream};

/// Trait that connected IO resources implement.
///
Expand Down Expand Up @@ -39,24 +37,20 @@ impl Connected for TcpStream {
#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
if let Some((inner, _)) = self.get_ref() {
inner.remote_addr()
} else {
None
}
let (inner, _) = self.get_ref();

inner.remote_addr()
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
if let Some((_, session)) = self.get_ref() {
if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
} else {
None
}
let (_, session) = self.get_ref();

if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
} else {
None
}
Expand Down
209 changes: 97 additions & 112 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use std::{
};
use tokio::io::{AsyncRead, AsyncWrite};

#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
#[cfg(not(feature = "tls"))]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
_server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
Expand All @@ -26,145 +26,130 @@ where
async_stream::try_stream! {
futures_util::pin_mut!(incoming);


while let Some(stream) = incoming.try_next().await? {
#[cfg(feature = "tls")]
{
if let Some(tls) = &server.tls {
let io = tls.accept(stream);
yield ServerIo::new(io);
continue;
}
}

yield ServerIo::new(stream);
}
}
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
futures_util::pin_mut!(incoming);

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}
#[cfg(feature = "tls")]
let mut tasks = futures_util::stream::futures_unordered::FuturesUnordered::new();

impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;
loop {
match select(&mut incoming, &mut tasks).await {
SelectOutput::Incoming(stream) => {
if let Some(tls) = &server.tls {
let tls = tls.clone();

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
let accept = tokio::spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new(io))
});

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
#[cfg(feature = "tls")]
pub(crate) struct TlsStream<IO> {
state: State<IO>,
}
tasks.push(accept);
} else {
yield ServerIo::new(stream);
}
}

#[cfg(feature = "tls")]
enum State<IO> {
Handshaking(tokio_rustls::Accept<IO>),
Streaming(tokio_rustls::server::TlsStream<IO>),
}
SelectOutput::Io(io) => {
yield io;
}

#[cfg(feature = "tls")]
impl<IO> TlsStream<IO> {
pub(crate) fn new(accept: tokio_rustls::Accept<IO>) -> Self {
TlsStream {
state: State::Handshaking(accept),
}
}
SelectOutput::Err(e) => {
tracing::error!(message = "Accept loop error.", error = %e);
}

pub(crate) fn get_ref(&self) -> Option<(&IO, &tokio_rustls::rustls::ServerSession)> {
if let State::Streaming(tls) = &self.state {
Some(tls.get_ref())
} else {
None
SelectOutput::Done => {
break;
}
}
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncRead for TlsStream<IO>
async fn select<IO, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered<
tokio::task::JoinHandle<Result<ServerIo, crate::Error>>,
>,
) -> SelectOutput<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
use futures_util::StreamExt;

if tasks.is_empty() {
return match incoming.try_next().await {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
};
}

tokio::select! {
stream = incoming.try_next() => {
match stream {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
accept = tasks.next() => {
match accept.expect("FuturesUnordered stream should never end") {
Ok(Ok(io)) => SelectOutput::Io(io),
Ok(Err(e)) => SelectOutput::Err(e),
Err(e) => SelectOutput::Err(e.into()),
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
#[cfg(feature = "tls")]
enum SelectOutput<A> {
Incoming(A),
Io(ServerIo),
Err(crate::Error),
Done,
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
2 changes: 1 addition & 1 deletion tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::service::TlsAcceptor;
use incoming::TcpIncoming;

#[cfg(feature = "tls")]
pub(crate) use incoming::TlsStream;
pub(crate) use tokio_rustls::server::TlsStream;

#[cfg(feature = "tls")]
use crate::transport::Error;
Expand Down
6 changes: 2 additions & 4 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,12 @@ impl TlsAcceptor {
})
}

pub(crate) fn accept<IO>(&self, io: IO) -> TlsStream<IO>
pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
let accept = acceptor.accept(io);

TlsStream::new(accept)
acceptor.accept(io).await.map_err(Into::into)
}
}

Expand Down

0 comments on commit 41c51f1

Please sign in to comment.