Skip to content

Commit

Permalink
fix(transport): Handle tls accepting on task (#320)
Browse files Browse the repository at this point in the history
Signed-off-by: Lucio Franco <luciofranco14@gmail.com>
  • Loading branch information
LucioFranco authored Apr 1, 2020
1 parent e5b1f8a commit 04a8c0c
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 27 deletions.
29 changes: 18 additions & 11 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#[cfg(feature = "tls")]
use super::TlsStream;
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};
use tokio_rustls::rustls::Session;

/// Trait that connected IO resources implement.
///
Expand Down Expand Up @@ -37,19 +39,24 @@ impl Connected for TcpStream {
#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
let (inner, _) = self.get_ref();
inner.remote_addr()
if let Some((inner, _)) = self.get_ref() {
inner.remote_addr()
} else {
None
}
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let (_, session) = self.get_ref();

if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
if let Some((_, session)) = self.get_ref() {
if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
} else {
None
}
} else {
None
}
Expand Down
113 changes: 104 additions & 9 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ use std::{
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use tracing::error;

#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
pub(crate) fn tcp_incoming<IO, IE>(
Expand All @@ -32,13 +30,7 @@ where
#[cfg(feature = "tls")]
{
if let Some(tls) = &server.tls {
let io = match tls.accept(stream).await {
Ok(io) => io,
Err(error) => {
error!(message = "Unable to accept incoming connection.", %error);
continue
},
};
let io = tls.accept(stream);
yield ServerIo::new(io);
continue;
}
Expand Down Expand Up @@ -73,3 +65,106 @@ impl Stream for TcpIncoming {
Pin::new(&mut self.inner).poll_accept(cx)
}
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
#[cfg(feature = "tls")]
pub(crate) struct TlsStream<IO> {
state: State<IO>,
}

#[cfg(feature = "tls")]
enum State<IO> {
Handshaking(tokio_rustls::Accept<IO>),
Streaming(tokio_rustls::server::TlsStream<IO>),
}

#[cfg(feature = "tls")]
impl<IO> TlsStream<IO> {
pub(crate) fn new(accept: tokio_rustls::Accept<IO>) -> Self {
TlsStream {
state: State::Handshaking(accept),
}
}

pub(crate) fn get_ref(&self) -> Option<(&IO, &tokio_rustls::rustls::ServerSession)> {
if let State::Streaming(tls) = &self.state {
Some(tls.get_ref())
} else {
None
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
3 changes: 3 additions & 0 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use super::service::TlsAcceptor;

use incoming::TcpIncoming;

#[cfg(feature = "tls")]
pub(crate) use incoming::TlsStream;

use super::service::{Or, Routes, ServerIo, ServiceBuilderExt};
use crate::{body::BoxBody, request::ConnectionInfo};
use futures_core::Stream;
Expand Down
14 changes: 7 additions & 7 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::io::BoxedIo;
use crate::transport::{server::Connected, Certificate, Identity};
use crate::transport::{
server::{Connected, TlsStream},
Certificate, Identity,
};
#[cfg(feature = "tls-roots")]
use rustls_native_certs;
use std::{fmt, sync::Arc};
Expand Down Expand Up @@ -157,17 +160,14 @@ impl TlsAcceptor {
})
}

pub(crate) async fn accept<IO>(
&self,
io: IO,
) -> Result<tokio_rustls::server::TlsStream<IO>, crate::Error>
pub(crate) fn accept<IO>(&self, io: IO) -> TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
let tls = acceptor.accept(io).await?;
let accept = acceptor.accept(io);

Ok(tls)
TlsStream::new(accept)
}
}

Expand Down

0 comments on commit 04a8c0c

Please sign in to comment.