diff --git a/Cargo.toml b/Cargo.toml index 857868c5..035b8d70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ description = "An RPC framework for Rust with a focus on ease of use." travis-ci = { repository = "google/tarpc" } [dependencies] -bincode = "1.0.0-alpha2" +bincode = "1.0.0-alpha4" byteorder = "1.0" cfg-if = "0.1.0" futures = "0.1.7" diff --git a/README.md b/README.md index da525aa0..d1d15e51 100644 --- a/README.md +++ b/README.md @@ -55,8 +55,8 @@ extern crate tarpc; use std::sync::mpsc; use std::thread; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt; +use tarpc::sync::{client, server}; +use tarpc::sync::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; service! { @@ -109,8 +109,8 @@ extern crate tarpc; extern crate tokio_core; use futures::Future; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; @@ -180,8 +180,9 @@ extern crate tarpc; extern crate tokio_core; use futures::Future; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; +use tarpc::tls; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; use tarpc::native_tls::{Pkcs12, TlsAcceptor}; @@ -216,7 +217,7 @@ fn main() { reactor.handle().spawn(server); let options = client::Options::default() .handle(reactor.handle()) - .tls(client::tls::Context::new("foobar.com").unwrap()); + .tls(tls::client::Context::new("foobar.com").unwrap()); reactor.run(FutureClient::connect(handle.addr(), options) .map_err(tarpc::Error::from) .and_then(|client| client.hello("Mom".to_string())) diff --git a/benches/latency.rs b/benches/latency.rs index c8357ed7..75ecbd76 100644 --- a/benches/latency.rs +++ b/benches/latency.rs @@ -14,8 +14,8 @@ extern crate env_logger; extern crate futures; extern crate tokio_core; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; #[cfg(test)] use test::Bencher; diff --git a/examples/concurrency.rs b/examples/concurrency.rs index 767057f6..e49fb8d1 100644 --- a/examples/concurrency.rs +++ b/examples/concurrency.rs @@ -25,8 +25,8 @@ use std::{cmp, thread}; use std::sync::{Arc, mpsc}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; diff --git a/examples/pubsub.rs b/examples/pubsub.rs index 7435314e..e198fa8d 100644 --- a/examples/pubsub.rs +++ b/examples/pubsub.rs @@ -21,8 +21,8 @@ use std::rc::Rc; use std::thread; use std::time::Duration; use subscriber::FutureServiceExt as SubscriberExt; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; use tarpc::util::{FirstSocketAddr, Message, Never}; use tokio_core::reactor; @@ -61,7 +61,7 @@ impl Subscriber { fn listen(id: u32, handle: &reactor::Handle, options: server::Options) - -> server::future::Handle { + -> server::Handle { let (server_handle, server) = Subscriber { id: id } .listen("localhost:0".first_socket_addr(), handle, options) .unwrap(); diff --git a/examples/readme_errors.rs b/examples/readme_errors.rs index dc3bf19b..36e3b44e 100644 --- a/examples/readme_errors.rs +++ b/examples/readme_errors.rs @@ -17,8 +17,8 @@ use std::error::Error; use std::fmt; use std::sync::mpsc; use std::thread; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt; +use tarpc::sync::{client, server}; +use tarpc::sync::client::ClientExt; service! { rpc hello(name: String) -> String | NoNameGiven; diff --git a/examples/readme_futures.rs b/examples/readme_futures.rs index 8e68ffd6..5515c9e8 100644 --- a/examples/readme_futures.rs +++ b/examples/readme_futures.rs @@ -12,8 +12,8 @@ extern crate tarpc; extern crate tokio_core; use futures::Future; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; diff --git a/examples/readme_sync.rs b/examples/readme_sync.rs index 44b3591b..9c2b9fa7 100644 --- a/examples/readme_sync.rs +++ b/examples/readme_sync.rs @@ -14,8 +14,8 @@ extern crate tokio_core; use std::sync::mpsc; use std::thread; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt; +use tarpc::sync::{client, server}; +use tarpc::sync::client::ClientExt; use tarpc::util::Never; service! { diff --git a/examples/server_calling_server.rs b/examples/server_calling_server.rs index 256e223c..1e227864 100644 --- a/examples/server_calling_server.rs +++ b/examples/server_calling_server.rs @@ -15,8 +15,8 @@ extern crate tokio_core; use add::{FutureService as AddFutureService, FutureServiceExt as AddExt}; use double::{FutureService as DoubleFutureService, FutureServiceExt as DoubleExt}; use futures::{BoxFuture, Future, Stream}; -use tarpc::{client, server}; -use tarpc::client::future::ClientExt as Fc; +use tarpc::future::{client, server}; +use tarpc::future::client::ClientExt as Fc; use tarpc::util::{FirstSocketAddr, Message, Never}; use tokio_core::reactor; diff --git a/examples/sync_server_calling_server.rs b/examples/sync_server_calling_server.rs index 6441b1fb..59c13fe3 100644 --- a/examples/sync_server_calling_server.rs +++ b/examples/sync_server_calling_server.rs @@ -16,8 +16,8 @@ use add::{SyncService as AddSyncService, SyncServiceExt as AddExt}; use double::{SyncService as DoubleSyncService, SyncServiceExt as DoubleExt}; use std::sync::mpsc; use std::thread; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt as Fc; +use tarpc::sync::{client, server}; +use tarpc::sync::client::ClientExt as Fc; use tarpc::util::{FirstSocketAddr, Message, Never}; pub mod add { @@ -69,7 +69,8 @@ fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { let handle = AddServer.listen("localhost:0".first_socket_addr(), - server::Options::default()).unwrap(); + server::Options::default()) + .unwrap(); tx.send(handle.addr()).unwrap(); handle.run(); }); @@ -80,7 +81,8 @@ fn main() { thread::spawn(move || { let add_client = add::SyncClient::connect(add, client::Options::default()).unwrap(); let handle = DoubleServer::new(add_client) - .listen("localhost:0".first_socket_addr(), server::Options::default()) + .listen("localhost:0".first_socket_addr(), + server::Options::default()) .unwrap(); tx.send(handle.addr()).unwrap(); handle.run(); diff --git a/examples/throughput.rs b/examples/throughput.rs index fd5e26f7..0c79b273 100644 --- a/examples/throughput.rs +++ b/examples/throughput.rs @@ -21,8 +21,8 @@ use std::sync::Arc; use std::sync::mpsc; use std::thread; use std::time; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt; +use tarpc::future::server; +use tarpc::sync::client::{self, ClientExt}; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; diff --git a/examples/two_clients.rs b/examples/two_clients.rs index df0eb113..1ef1000d 100644 --- a/examples/two_clients.rs +++ b/examples/two_clients.rs @@ -19,8 +19,9 @@ use bar::FutureServiceExt as BarExt; use baz::FutureServiceExt as BazExt; use std::sync::mpsc; use std::thread; -use tarpc::{client, server}; -use tarpc::client::sync::ClientExt; +use tarpc::future::server; +use tarpc::sync::client; +use tarpc::sync::client::ClientExt; use tarpc::util::{FirstSocketAddr, Never}; use tokio_core::reactor; diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index b0e9fd3f..00000000 --- a/src/client.rs +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#[cfg(feature = "tls")] -use self::tls::*; -use {WireError, bincode}; -use tokio_core::reactor; - -type WireResponse = Result>, bincode::Error>; - -/// TLS-specific functionality -#[cfg(feature = "tls")] -pub mod tls { - use native_tls::{Error, TlsConnector}; - - /// TLS context for client - pub struct Context { - /// Domain to connect to - pub domain: String, - /// TLS connector - pub tls_connector: TlsConnector, - } - - impl Context { - /// Try to construct a new `Context`. - /// - /// The provided domain will be used for both - /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname - /// validation. - pub fn new>(domain: S) -> Result { - Ok(Context { - domain: domain.into(), - tls_connector: TlsConnector::builder()?.build()?, - }) - } - - /// Construct a new `Context` using the provided domain and `TlsConnector` - /// - /// The domain will be used for both - /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname - /// validation. - pub fn from_connector>(domain: S, tls_connector: TlsConnector) -> Self { - Context { - domain: domain.into(), - tls_connector: tls_connector, - } - } - } -} - -/// Additional options to configure how the client connects and operates. -#[derive(Default)] -pub struct Options { - reactor: Option, - #[cfg(feature = "tls")] - tls_ctx: Option, -} - -impl Options { - /// Drive using the given reactor handle. Only used by `FutureClient`s. - pub fn handle(mut self, handle: reactor::Handle) -> Self { - self.reactor = Some(Reactor::Handle(handle)); - self - } - - /// Drive using the given reactor remote. Only used by `FutureClient`s. - pub fn remote(mut self, remote: reactor::Remote) -> Self { - self.reactor = Some(Reactor::Remote(remote)); - self - } - - /// Connect using the given `Context` - #[cfg(feature = "tls")] - pub fn tls(mut self, tls_ctx: Context) -> Self { - self.tls_ctx = Some(tls_ctx); - self - } -} - -enum Reactor { - Handle(reactor::Handle), - Remote(reactor::Remote), -} - -/// Exposes a trait for connecting asynchronously to servers. -pub mod future { - use super::{Options, Reactor, WireResponse}; - use {REMOTE, WireError}; - #[cfg(feature = "tls")] - use errors::native_to_io; - use futures::{self, Future, future}; - use protocol::Proto; - use serde::{Deserialize, Serialize}; - use std::fmt; - use std::io; - use std::net::SocketAddr; - use stream_type::StreamType; - use tokio_core::net::TcpStream; - use tokio_core::reactor; - use tokio_proto::BindClient as ProtoBindClient; - use tokio_proto::multiplex::Multiplex; - use tokio_service::Service; - #[cfg(feature = "tls")] - use tokio_tls::TlsConnectorExt; - - #[doc(hidden)] - pub struct Client - where Req: Serialize + 'static, - Resp: Deserialize + 'static, - E: Deserialize + 'static - { - inner: BindClient, - } - - impl Clone for Client - where Req: Serialize + 'static, - Resp: Deserialize + 'static, - E: Deserialize + 'static - { - fn clone(&self) -> Self { - Client { inner: self.inner.clone() } - } - } - - impl Service for Client - where Req: Serialize + Sync + Send + 'static, - Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static - { - type Request = Req; - type Response = Resp; - type Error = ::Error; - type Future = ResponseFuture; - - fn call(&self, request: Self::Request) -> Self::Future { - fn identity(t: T) -> T { - t - } - self.inner - .call(request) - .map(Self::map_err as _) - .map_err(::Error::from as _) - .and_then(identity as _) - } - } - - impl Client - where Req: Serialize + 'static, - Resp: Deserialize + 'static, - E: Deserialize + 'static - { - fn new(inner: BindClient) -> Self - where Req: Serialize + Sync + Send + 'static, - Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static - { - Client { inner: inner } - } - - fn map_err(resp: WireResponse) -> Result> { - resp.map(|r| r.map_err(::Error::from)) - .map_err(::Error::ClientSerialize) - .and_then(|r| r) - } - } - - impl fmt::Debug for Client - where Req: Serialize + 'static, - Resp: Deserialize + 'static, - E: Deserialize + 'static - { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "Client {{ .. }}") - } - } - - /// Extension methods for clients. - pub trait ClientExt: Sized { - /// The type of the future returned when calling `connect`. - type ConnectFut: Future; - - /// Connects to a server located at the given address, using the given options. - fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut; - } - - /// A future that resolves to a `Client` or an `io::Error`. - pub type ConnectFuture = - futures::Flatten>>, - fn(futures::Canceled) -> io::Error>>; - - impl ClientExt for Client - where Req: Serialize + Sync + Send + 'static, - Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static - { - type ConnectFut = ConnectFuture; - - fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut { - // we need to do this for tls because we need to avoid moving the entire `Options` - // struct into the `setup` closure, since `Reactor` is not `Send`. - #[cfg(feature = "tls")] - let mut options = options; - #[cfg(feature = "tls")] - let tls_ctx = options.tls_ctx.take(); - - let connect = move |handle: &reactor::Handle| { - let handle2 = handle.clone(); - TcpStream::connect(&addr, handle) - .and_then(move |socket| { - #[cfg(feature = "tls")] - match tls_ctx { - Some(tls_ctx) => { - future::Either::A(tls_ctx.tls_connector - .connect_async(&tls_ctx.domain, socket) - .map(StreamType::Tls) - .map_err(native_to_io)) - } - None => future::Either::B(future::ok(StreamType::Tcp(socket))), - } - #[cfg(not(feature = "tls"))] - future::ok(StreamType::Tcp(socket)) - }) - .map(move |tcp| Client::new(Proto::new().bind_client(&handle2, tcp))) - }; - let (tx, rx) = futures::oneshot(); - let setup = move |handle: &reactor::Handle| { - connect(handle).then(move |result| { - tx.complete(result); - Ok(()) - }) - }; - - match options.reactor { - Some(Reactor::Handle(handle)) => { - handle.spawn(setup(&handle)); - } - Some(Reactor::Remote(remote)) => { - remote.spawn(setup); - } - None => { - REMOTE.spawn(setup); - } - } - fn panic(canceled: futures::Canceled) -> io::Error { - unreachable!(canceled) - } - rx.map_err(panic as _).flatten() - } - } - - type ResponseFuture = - futures::AndThen as Service>::Future, - fn(WireResponse) -> Result>>, - fn(io::Error) -> ::Error>, - Result>, - fn(Result>) -> Result>>; - type BindClient = - >> - as ProtoBindClient>::BindClient; -} - -/// Exposes a trait for connecting synchronously to servers. -pub mod sync { - use futures::{self, Future, Stream}; - use super::Options; - use super::future::{Client as FutureClient, ClientExt as FutureClientExt}; - use serde::{Deserialize, Serialize}; - use std::fmt; - use std::io; - use std::net::ToSocketAddrs; - use std::sync::mpsc; - use std::thread; - use tokio_core::reactor; - use tokio_service::Service; - use util::FirstSocketAddr; - - #[doc(hidden)] - pub struct Client { - request: futures::sync::mpsc::UnboundedSender<(Req, mpsc::Sender>>)>, - } - - impl Clone for Client { - fn clone(&self) -> Self { - Client { - request: self.request.clone(), - } - } - } - - impl fmt::Debug for Client { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "Client {{ .. }}") - } - } - - impl Client - where Req: Serialize + Sync + Send + 'static, - Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static - { - /// Drives an RPC call for the given request. - pub fn call(&self, request: Req) -> Result> { - let (tx, rx) = mpsc::channel(); - self.request.send((request, tx)).unwrap(); - rx.recv().unwrap() - } - } - - /// Extension methods for Clients. - pub trait ClientExt: Sized { - /// Connects to a server located at the given address. - fn connect(addr: A, options: Options) -> io::Result where A: ToSocketAddrs; - } - - impl ClientExt for Client - where Req: Serialize + Sync + Send + 'static, - Resp: Deserialize + Sync + Send + 'static, - E: Deserialize + Sync + Send + 'static - { - fn connect(addr: A, _options: Options) -> io::Result - where A: ToSocketAddrs - { - let addr = addr.try_first_socket_addr()?; - let (connect_tx, connect_rx) = mpsc::channel(); - let (request, request_rx) = futures::sync::mpsc::unbounded(); - #[cfg(feature = "tls")] - let tls_ctx = _options.tls_ctx; - thread::spawn(move || { - let mut reactor = match reactor::Core::new() { - Ok(reactor) => reactor, - Err(e) => { - connect_tx.send(Err(e)).unwrap(); - return; - } - }; - let options; - #[cfg(feature = "tls")] - { - let mut opts = Options::default().handle(reactor.handle()); - opts.tls_ctx = tls_ctx; - options = opts; - } - #[cfg(not(feature = "tls"))] - { - options = Options::default().handle(reactor.handle()); - } - let client = match reactor.run(FutureClient::connect(addr, options)) { - Ok(client) => { - connect_tx.send(Ok(())).unwrap(); - client - } - Err(e) => { - connect_tx.send(Err(e)).unwrap(); - return; - } - }; - let handle = reactor.handle(); - let requests = request_rx.for_each(|(request, response_tx): (_, mpsc::Sender<_>)| { - handle.spawn(client.call(request) - .then(move |response| { - Ok(response_tx.send(response).unwrap()) - })); - Ok(()) - }); - reactor.run(requests).unwrap(); - }); - connect_rx.recv().unwrap()?; - Ok(Client { request }) - } - } -} diff --git a/src/errors.rs b/src/errors.rs index 1916f744..1bc15129 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,16 +12,16 @@ use std::error::Error as StdError; pub enum Error { /// Any IO error. Io(io::Error), - /// Error serializing the client request or deserializing the server response. + /// Error deserializing the server response. /// /// Typically this indicates a faulty implementation of `serde::Serialize` or /// `serde::Deserialize`. - ClientSerialize(::bincode::Error), - /// Error serializing the server response or deserializing the client request. + ResponseDeserialize(::bincode::Error), + /// Error deserializing the client request. /// /// Typically this indicates a faulty implementation of `serde::Serialize` or /// `serde::Deserialize`. - ServerSerialize(String), + RequestDeserialize(String), /// The server was unable to reply to the rpc for some reason. /// /// This is a service-specific error. Its type is individually specified in the @@ -32,8 +32,8 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Error::ClientSerialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), - Error::ServerSerialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), + Error::ResponseDeserialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), + Error::RequestDeserialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), Error::App(ref e) => fmt::Display::fmt(e, f), Error::Io(ref e) => fmt::Display::fmt(e, f), } @@ -43,12 +43,8 @@ impl fmt::Display for Er impl StdError for Error { fn description(&self) -> &str { match *self { - Error::ClientSerialize(_) => { - "The client failed to serialize the request or deserialize the response." - } - Error::ServerSerialize(_) => { - "The server failed to serialize the response or deserialize the request." - } + Error::ResponseDeserialize(_) => "The client failed to deserialize the response.", + Error::RequestDeserialize(_) => "The server failed to deserialize the request.", Error::App(ref e) => e.description(), Error::Io(ref e) => e.description(), } @@ -56,8 +52,8 @@ impl StdError for Error< fn cause(&self) -> Option<&StdError> { match *self { - Error::ClientSerialize(ref e) => e.cause(), - Error::ServerSerialize(_) | + Error::ResponseDeserialize(ref e) => e.cause(), + Error::RequestDeserialize(_) | Error::App(_) => None, Error::Io(ref e) => e.cause(), } @@ -73,7 +69,7 @@ impl From for Error { impl From> for Error { fn from(err: WireError) -> Self { match err { - WireError::ServerSerialize(s) => Error::ServerSerialize(s), + WireError::RequestDeserialize(s) => Error::RequestDeserialize(s), WireError::App(e) => Error::App(e), } } @@ -83,8 +79,8 @@ impl From> for Error { #[doc(hidden)] #[derive(Deserialize, Serialize, Clone, Debug)] pub enum WireError { - /// Error in serializing the server response or deserializing the client request. - ServerSerialize(String), + /// Server-side error in deserializing the client request. + RequestDeserialize(String), /// The server was unable to reply to the rpc for some reason. App(E), } diff --git a/src/future/client.rs b/src/future/client.rs new file mode 100644 index 00000000..1f5e7915 --- /dev/null +++ b/src/future/client.rs @@ -0,0 +1,219 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +use {REMOTE, bincode}; +use future::server::Response; +use futures::{self, Future, future}; +use protocol::Proto; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::io; +use std::net::SocketAddr; +use stream_type::StreamType; +use tokio_core::net::TcpStream; +use tokio_core::reactor; +use tokio_proto::BindClient as ProtoBindClient; +use tokio_proto::multiplex::ClientService; +use tokio_service::Service; + +cfg_if! { + if #[cfg(feature = "tls")] { + use errors::native_to_io; + use tls::client::Context; + use tokio_tls::TlsConnectorExt; + } else {} +} + +/// Additional options to configure how the client connects and operates. +#[derive(Default)] +pub struct Options { + reactor: Option, + #[cfg(feature = "tls")] + tls_ctx: Option, +} + +impl Options { + /// Drive using the given reactor handle. Only used by `FutureClient`s. + pub fn handle(mut self, handle: reactor::Handle) -> Self { + self.reactor = Some(Reactor::Handle(handle)); + self + } + + /// Drive using the given reactor remote. Only used by `FutureClient`s. + pub fn remote(mut self, remote: reactor::Remote) -> Self { + self.reactor = Some(Reactor::Remote(remote)); + self + } + + /// Connect using the given `Context` + #[cfg(feature = "tls")] + pub fn tls(mut self, tls_ctx: Context) -> Self { + self.tls_ctx = Some(tls_ctx); + self + } +} + +enum Reactor { + Handle(reactor::Handle), + Remote(reactor::Remote), +} + +#[doc(hidden)] +pub struct Client + where Req: Serialize + 'static, + Resp: Deserialize + 'static, + E: Deserialize + 'static +{ + inner: ClientService>>, +} + +impl Clone for Client + where Req: Serialize + 'static, + Resp: Deserialize + 'static, + E: Deserialize + 'static +{ + fn clone(&self) -> Self { + Client { inner: self.inner.clone() } + } +} + +impl Service for Client + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static +{ + type Request = Req; + type Response = Resp; + type Error = ::Error; + type Future = ResponseFuture; + + fn call(&self, request: Self::Request) -> Self::Future { + fn identity(t: T) -> T { + t + } + self.inner + .call(request) + .map(Self::map_err as _) + .map_err(::Error::from as _) + .and_then(identity as _) + } +} + +impl Client + where Req: Serialize + 'static, + Resp: Deserialize + 'static, + E: Deserialize + 'static +{ + fn bind(handle: &reactor::Handle, tcp: StreamType) -> Self + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static + { + let inner = Proto::new().bind_client(&handle, tcp); + Client { inner } + } + + fn map_err(resp: WireResponse) -> Result> { + resp.map(|r| r.map_err(::Error::from)) + .map_err(::Error::ResponseDeserialize) + .and_then(|r| r) + } +} + +impl fmt::Debug for Client + where Req: Serialize + 'static, + Resp: Deserialize + 'static, + E: Deserialize + 'static +{ + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "Client {{ .. }}") + } +} + +/// Extension methods for clients. +pub trait ClientExt: Sized { + /// The type of the future returned when calling `connect`. + type ConnectFut: Future; + + /// Connects to a server located at the given address, using the given options. + fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut; +} + +/// A future that resolves to a `Client` or an `io::Error`. +pub type ConnectFuture = + futures::Flatten>>, + fn(futures::Canceled) -> io::Error>>; + +impl ClientExt for Client + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static +{ + type ConnectFut = ConnectFuture; + + fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut { + // we need to do this for tls because we need to avoid moving the entire `Options` + // struct into the `setup` closure, since `Reactor` is not `Send`. + #[cfg(feature = "tls")] + let mut options = options; + #[cfg(feature = "tls")] + let tls_ctx = options.tls_ctx.take(); + + let connect = move |handle: &reactor::Handle| { + let handle2 = handle.clone(); + TcpStream::connect(&addr, handle) + .and_then(move |socket| { + // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the + // ServerProto impl + #[cfg(feature = "tls")] + match tls_ctx { + Some(tls_ctx) => { + future::Either::A(tls_ctx.tls_connector + .connect_async(&tls_ctx.domain, socket) + .map(StreamType::Tls) + .map_err(native_to_io)) + } + None => future::Either::B(future::ok(StreamType::Tcp(socket))), + } + #[cfg(not(feature = "tls"))] + future::ok(StreamType::Tcp(socket)) + }) + .map(move |tcp| Client::bind(&handle2, tcp)) + }; + let (tx, rx) = futures::oneshot(); + let setup = move |handle: &reactor::Handle| { + connect(handle).then(move |result| { + tx.complete(result); + Ok(()) + }) + }; + + match options.reactor { + Some(Reactor::Handle(handle)) => { + handle.spawn(setup(&handle)); + } + Some(Reactor::Remote(remote)) => { + remote.spawn(setup); + } + None => { + REMOTE.spawn(setup); + } + } + fn panic(canceled: futures::Canceled) -> io::Error { + unreachable!(canceled) + } + rx.map_err(panic as _).flatten() + } +} + +type ResponseFuture = + futures::AndThen>> as Service>::Future, + fn(WireResponse) -> Result>>, + fn(io::Error) -> ::Error>, + Result>, + fn(Result>) -> Result>>; + +type WireResponse = Result, bincode::Error>; diff --git a/src/future/mod.rs b/src/future/mod.rs new file mode 100644 index 00000000..79011a35 --- /dev/null +++ b/src/future/mod.rs @@ -0,0 +1,4 @@ +/// Provides the base client stubs used by the service macro. +pub mod client; +/// Provides the base server boilerplate used by service implementations. +pub mod server; diff --git a/src/server.rs b/src/future/server.rs similarity index 82% rename from src/server.rs rename to src/future/server.rs index d6f8ec3c..c07269e6 100644 --- a/src/server.rs +++ b/src/future/server.rs @@ -3,12 +3,11 @@ // Licensed under the MIT License, . // This file may not be copied, modified, or distributed except according to those terms. -use bincode; +use {bincode, net2}; use errors::WireError; use futures::{Future, Poll, Stream, future as futures, stream}; use futures::sync::{mpsc, oneshot}; use futures::unsync; -use net2; use protocol::Proto; use serde::{Deserialize, Serialize}; use std::cell::Cell; @@ -30,6 +29,47 @@ cfg_if! { } else {} } +/// A handle to a bound server. +#[derive(Clone)] +pub struct Handle { + addr: SocketAddr, + shutdown: Shutdown, +} + +impl Handle { + #[doc(hidden)] + pub fn listen(new_service: S, + addr: SocketAddr, + handle: &reactor::Handle, + options: Options) + -> io::Result<(Self, Listen)> + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static + { + let (addr, shutdown, server) = + listen_with(new_service, addr, handle, Acceptor::from(options))?; + Ok((Handle { + addr: addr, + shutdown: shutdown, + }, + server)) + } + + /// Returns a hook for shutting down the server. + pub fn shutdown(&self) -> &Shutdown { + &self.shutdown + } + + /// The socket address the server is bound to. + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + enum Acceptor { Tcp, #[cfg(feature = "tls")] @@ -46,6 +86,7 @@ type Accept = futures::Either; impl Acceptor { + // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the ServerProto impl #[cfg(feature = "tls")] fn accept(&self, socket: TcpStream) -> Accept { match *self { @@ -152,7 +193,7 @@ impl Shutdown { /// connections are closed, it initates total shutdown. /// /// This fn will not return until the server is completely shut down. - pub fn shutdown(self) -> ShutdownFuture { + pub fn shutdown(&self) -> ShutdownFuture { let (tx, rx) = oneshot::channel(); let inner = if let Err(_) = self.tx.send(tx) { trace!("Server already initiated shutdown."); @@ -229,109 +270,6 @@ impl NewService for ConnectionTrackingNewService { } } -/// Future-specific server utilities. -pub mod future { - pub use super::*; - - /// A handle to a bound server. - #[derive(Clone)] - pub struct Handle { - addr: SocketAddr, - shutdown: Shutdown, - } - - impl Handle { - #[doc(hidden)] - pub fn listen(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - options: Options) - -> io::Result<(Self, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static - { - let (addr, shutdown, server) = - listen_with(new_service, addr, handle, Acceptor::from(options))?; - Ok((Handle { - addr: addr, - shutdown: shutdown, - }, - server)) - } - - /// Returns a hook for shutting down the server. - pub fn shutdown(&self) -> Shutdown { - self.shutdown.clone() - } - - /// The socket address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.addr - } - } -} - -/// Sync-specific server utilities. -pub mod sync { - pub use super::*; - - /// A handle to a bound server. Must be run to start serving requests. - #[must_use = "A server does nothing until `run` is called."] - pub struct Handle { - reactor: reactor::Core, - handle: future::Handle, - server: Box>, - } - - impl Handle { - #[doc(hidden)] - pub fn listen(new_service: S, - addr: SocketAddr, - options: Options) - -> io::Result - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: Deserialize + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static - { - let reactor = reactor::Core::new()?; - let (handle, server) = - future::Handle::listen(new_service, addr, &reactor.handle(), options)?; - let server = Box::new(server); - Ok(Handle { - reactor: reactor, - handle: handle, - server: server, - }) - } - - /// Runs the server on the current thread, blocking indefinitely. - pub fn run(mut self) { - trace!("Running..."); - match self.reactor.run(self.server) { - Ok(()) => debug!("Server successfully shutdown."), - Err(()) => debug!("Server shutdown due to error."), - } - } - - /// Returns a hook for shutting down the server. - pub fn shutdown(&self) -> Shutdown { - self.handle.shutdown().clone() - } - - /// The socket address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.handle.addr() - } - } -} - struct ShutdownSetter { shutdown: Rc>>>, } @@ -661,20 +599,28 @@ impl Fn<(I,)> for Bind fn listener(addr: &SocketAddr, handle: &reactor::Handle) -> io::Result { const PENDING_CONNECTION_BACKLOG: i32 = 1024; - #[cfg(unix)] - use net2::unix::UnixTcpBuilderExt; let builder = match *addr { SocketAddr::V4(_) => net2::TcpBuilder::new_v4(), SocketAddr::V6(_) => net2::TcpBuilder::new_v6(), }?; - + configure_tcp(&builder)?; builder.reuse_address(true)?; - - #[cfg(unix)] - builder.reuse_port(true)?; - builder.bind(addr)? .listen(PENDING_CONNECTION_BACKLOG) .and_then(|l| TcpListener::from_listener(l, addr, handle)) } + +#[cfg(unix)] +fn configure_tcp(tcp: &net2::TcpBuilder) -> io::Result<()> { + use net2::unix::UnixTcpBuilderExt; + + tcp.reuse_port(true)?; + + Ok(()) +} + +#[cfg(windows)] +fn configure_tcp(_tcp: &net2::TcpBuilder) -> io::Result<()> { + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index dbed3267..6fe61b7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,8 +34,8 @@ //! extern crate tarpc; //! extern crate tokio_core; //! -//! use tarpc::{client, server}; -//! use tarpc::client::sync::ClientExt; +//! use tarpc::sync::{client, server}; +//! use tarpc::sync::client::ClientExt; //! use tarpc::util::Never; //! use tokio_core::reactor; //! use std::sync::mpsc; @@ -63,22 +63,23 @@ //! handle.run(); //! }); //! let addr = rx.recv().unwrap(); -//! let mut client = SyncClient::connect(addr, client::Options::default()).unwrap(); +//! let client = SyncClient::connect(addr, client::Options::default()).unwrap(); //! println!("{}", client.hello("Mom".to_string()).unwrap()); //! } //! ``` //! //! Example usage with TLS: //! -//! ```ignore +//! ```no-run //! #![feature(plugin)] //! #![plugin(tarpc_plugins)] //! //! #[macro_use] //! extern crate tarpc; //! -//! use tarpc::{client, server}; -//! use tarpc::client::sync::ClientExt; +//! use tarpc::sync::{client, server}; +//! use tarpc::sync::client::ClientExt; +//! use tarpc::tls; //! use tarpc::util::Never; //! use tarpc::native_tls::{TlsAcceptor, Pkcs12}; //! @@ -105,9 +106,9 @@ //! let addr = "localhost:10000"; //! let acceptor = get_acceptor(); //! let _server = HelloServer.listen(addr, server::Options::default().tls(acceptor)); -//! let mut client = SyncClient::connect(addr, +//! let client = SyncClient::connect(addr, //! client::Options::default() -//! .tls(client::tls::Context::new("foobar.com").unwrap())) +//! .tls(tls::client::Context::new("foobar.com").unwrap())) //! .unwrap(); //! println!("{}", client.hello("Mom".to_string()).unwrap()); //! } @@ -152,10 +153,13 @@ pub mod util; /// Provides the macro used for constructing rpc services and client stubs. #[macro_use] mod macros; -/// Provides the base client stubs used by the service macro. -pub mod client; -/// Provides the base server boilerplate used by service implementations. -pub mod server; +/// Synchronous version of the tarpc API +pub mod sync; +/// Futures-based version of the tarpc API. +pub mod future; +/// TLS-specific functionality. +#[cfg(feature = "tls")] +pub mod tls; /// Provides implementations of `ClientProto` and `ServerProto` that implement the tarpc protocol. /// The tarpc protocol is a length-delimited, bincode-serialized payload. mod protocol; @@ -193,7 +197,7 @@ cfg_if! { extern crate tokio_tls; extern crate native_tls as native_tls_inner; - /// Re-exported TLS-related types + /// Re-exported TLS-related types from the `native_tls` crate. pub mod native_tls { pub use native_tls_inner::{Error, Pkcs12, TlsAcceptor, TlsConnector}; } diff --git a/src/macros.rs b/src/macros.rs index 6679ae8b..27aed6d4 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -283,7 +283,6 @@ macro_rules! service { #[doc(hidden)] #[allow(non_camel_case_types, unused)] - #[derive(Debug)] pub enum tarpc_service_Request__ { NotIrrefutable(()), $( @@ -296,7 +295,6 @@ macro_rules! service { #[doc(hidden)] #[allow(non_camel_case_types, unused)] - #[derive(Debug)] pub enum tarpc_service_Response__ { NotIrrefutable(()), $( @@ -351,7 +349,7 @@ macro_rules! service { #[allow(non_camel_case_types)] type tarpc_service_Future__ = - $crate::futures::Finished<$crate::server::Response, ::std::io::Error>; @@ -368,7 +366,7 @@ macro_rules! service { } impl $crate::futures::Future for tarpc_service_FutureReply__ { - type Item = $crate::server::Response; type Error = ::std::io::Error; @@ -399,7 +397,7 @@ macro_rules! service { { type Request = ::std::result::Result; - type Response = $crate::server::Response; type Error = ::std::io::Error; type Future = tarpc_service_FutureReply__; @@ -411,7 +409,7 @@ macro_rules! service { return tarpc_service_FutureReply__::DeserializeError( $crate::futures::finished( ::std::result::Result::Err( - $crate::WireError::ServerSerialize( + $crate::WireError::RequestDeserialize( ::std::string::ToString::to_string( &tarpc_service_deserialize_err__))))); } @@ -466,7 +464,7 @@ macro_rules! service { pub struct Listen where S: FutureService, { - inner: $crate::server::Listen, + inner: $crate::future::server::Listen, tarpc_service_Request__, tarpc_service_Response__, tarpc_service_Error__>, @@ -494,10 +492,10 @@ macro_rules! service { fn listen(self, addr: ::std::net::SocketAddr, handle: &$crate::tokio_core::reactor::Handle, - options: $crate::server::Options) - -> ::std::io::Result<($crate::server::future::Handle, Listen)> + options: $crate::future::server::Options) + -> ::std::io::Result<($crate::future::server::Handle, Listen)> { - $crate::server::future::Handle::listen(tarpc_service_AsyncServer__(self), + $crate::future::server::Handle::listen(tarpc_service_AsyncServer__(self), addr, handle, options) @@ -525,8 +523,8 @@ macro_rules! service { /// Spawns the service, binding to the given address and returning the server handle. /// /// To actually run the server, call `run` on the returned handle. - fn listen(self, addr: A, options: $crate::server::Options) - -> ::std::io::Result<$crate::server::sync::Handle> + fn listen(self, addr: A, options: $crate::sync::server::Options) + -> ::std::io::Result<$crate::sync::server::Handle> where A: ::std::net::ToSocketAddrs { let tarpc_service__ = tarpc_service_AsyncServer__(SyncServer__ { @@ -536,7 +534,7 @@ macro_rules! service { let tarpc_service_addr__ = $crate::util::FirstSocketAddr::try_first_socket_addr(&addr)?; - return $crate::server::sync::Handle::listen(tarpc_service__, + return $crate::sync::server::Handle::listen(tarpc_service__, tarpc_service_addr__, options); @@ -595,12 +593,13 @@ macro_rules! service { inner: tarpc_service_SyncClient__, } - impl $crate::client::sync::ClientExt for SyncClient { - fn connect(addr_: A, options_: $crate::client::Options) -> ::std::io::Result + impl $crate::sync::client::ClientExt for SyncClient { + fn connect(addr_: A, options_: $crate::sync::client::Options) + -> ::std::io::Result where A: ::std::net::ToSocketAddrs, { let client_ = ::connect(addr_, options_)?; + as $crate::sync::client::ClientExt>::connect(addr_, options_)?; ::std::result::Result::Ok(SyncClient { inner: client_, }) @@ -642,11 +641,11 @@ macro_rules! service { unreachable!() } } - $crate::Error::ServerSerialize(tarpc_service_err__) => { - $crate::Error::ServerSerialize(tarpc_service_err__) + $crate::Error::RequestDeserialize(tarpc_service_err__) => { + $crate::Error::RequestDeserialize(tarpc_service_err__) } - $crate::Error::ClientSerialize(tarpc_service_err__) => { - $crate::Error::ClientSerialize(tarpc_service_err__) + $crate::Error::ResponseDeserialize(tarpc_service_err__) => { + $crate::Error::ResponseDeserialize(tarpc_service_err__) } $crate::Error::Io(tarpc_service_error__) => { $crate::Error::Io(tarpc_service_error__) @@ -661,27 +660,27 @@ macro_rules! service { #[allow(non_camel_case_types)] type tarpc_service_FutureClient__ = - $crate::client::future::Client; #[allow(non_camel_case_types)] type tarpc_service_SyncClient__ = - $crate::client::sync::Client; #[allow(non_camel_case_types)] - /// Implementation detail: Pending connection. - pub struct tarpc_service_ConnectFuture__ { - inner: $crate::futures::Map<$crate::client::future::ConnectFuture< + /// A future representing a client connecting to a server. + pub struct Connect { + inner: $crate::futures::Map<$crate::future::client::ConnectFuture< tarpc_service_Request__, tarpc_service_Response__, tarpc_service_Error__>, fn(tarpc_service_FutureClient__) -> T>, } - impl $crate::futures::Future for tarpc_service_ConnectFuture__ { + impl $crate::futures::Future for Connect { type Item = T; type Error = ::std::io::Error; @@ -695,18 +694,18 @@ macro_rules! service { /// The client stub that makes RPC calls to the server. Exposes a Future interface. pub struct FutureClient(tarpc_service_FutureClient__); - impl<'a> $crate::client::future::ClientExt for FutureClient { - type ConnectFut = tarpc_service_ConnectFuture__; + impl<'a> $crate::future::client::ClientExt for FutureClient { + type ConnectFut = Connect; fn connect(tarpc_service_addr__: ::std::net::SocketAddr, - tarpc_service_options__: $crate::client::Options) + tarpc_service_options__: $crate::future::client::Options) -> Self::ConnectFut { let client = ::connect(tarpc_service_addr__, + as $crate::future::client::ClientExt>::connect(tarpc_service_addr__, tarpc_service_options__); - tarpc_service_ConnectFuture__ { + Connect { inner: $crate::futures::Future::map(client, FutureClient) } } @@ -754,11 +753,11 @@ macro_rules! service { unreachable!() } } - $crate::Error::ServerSerialize(tarpc_service_err__) => { - $crate::Error::ServerSerialize(tarpc_service_err__) + $crate::Error::RequestDeserialize(tarpc_service_err__) => { + $crate::Error::RequestDeserialize(tarpc_service_err__) } - $crate::Error::ClientSerialize(tarpc_service_err__) => { - $crate::Error::ClientSerialize(tarpc_service_err__) + $crate::Error::ResponseDeserialize(tarpc_service_err__) => { + $crate::Error::ResponseDeserialize(tarpc_service_err__) } $crate::Error::Io(tarpc_service_error__) => { $crate::Error::Io(tarpc_service_error__) @@ -801,7 +800,7 @@ mod syntax_test { #[cfg(test)] mod functional_test { - use {client, server}; + use {sync, future}; use futures::{Future, failed}; use std::io; use std::net::SocketAddr; @@ -825,14 +824,21 @@ mod functional_test { if #[cfg(feature = "tls")] { const DOMAIN: &'static str = "foobar.com"; - use client::tls::Context; + use tls::client::Context; use native_tls::{Pkcs12, TlsAcceptor, TlsConnector}; - fn get_tls_server_options() -> server::Options { + fn get_tls_acceptor() -> TlsAcceptor { let buf = include_bytes!("../test/identity.p12"); let pkcs12 = unwrap!(Pkcs12::from_der(buf, "mypass")); - let acceptor = unwrap!(unwrap!(TlsAcceptor::builder(pkcs12)).build()); - server::Options::default().tls(acceptor) + unwrap!(unwrap!(TlsAcceptor::builder(pkcs12)).build()) + } + + fn get_future_tls_server_options() -> future::server::Options { + future::server::Options::default().tls(get_tls_acceptor()) + } + + fn get_sync_tls_server_options() -> sync::server::Options { + sync::server::Options::default().tls(get_tls_acceptor()) } // Making the TlsConnector for testing needs to be OS-dependent just like native-tls. @@ -846,32 +852,47 @@ mod functional_test { use self::security_framework::certificate::SecCertificate; use native_tls_inner::backend::security_framework::TlsConnectorBuilderExt; - fn get_tls_client_options() -> client::Options { + fn get_future_tls_client_options() -> future::client::Options { + future::client::Options::default().tls(get_tls_client_context()) + } + + fn get_sync_tls_client_options() -> sync::client::Options { + sync::client::Options::default().tls(get_tls_client_context()) + } + + fn get_tls_client_context() -> Context { let buf = include_bytes!("../test/root-ca.der"); let cert = unwrap!(SecCertificate::from_der(buf)); let mut connector = unwrap!(TlsConnector::builder()); connector.anchor_certificates(&[cert]); - client::Options::default() - .tls(Context { - domain: DOMAIN.into(), - tls_connector: unwrap!(connector.build()), - }) + Context { + domain: DOMAIN.into(), + tls_connector: unwrap!(connector.build()), + } } } else if #[cfg(all(not(target_os = "macos"), not(windows)))] { use native_tls_inner::backend::openssl::TlsConnectorBuilderExt; - fn get_tls_client_options() -> client::Options { + fn get_sync_tls_client_options() -> sync::client::Options { + sync::client::Options::default() + .tls(get_tls_client_context()) + } + + fn get_future_tls_client_options() -> future::client::Options { + future::client::Options::default() + .tls(get_tls_client_context()) + } + + fn get_tls_client_context() -> Context { let mut connector = unwrap!(TlsConnector::builder()); unwrap!(connector.builder_mut() .builder_mut() .set_ca_file("test/root-ca.pem")); - - client::Options::default() - .tls(Context { - domain: DOMAIN.into(), - tls_connector: unwrap!(connector.build()), - }) + Context { + domain: DOMAIN.into(), + tls_connector: unwrap!(connector.build()), + } } // not implemented for windows or other platforms } else { @@ -882,22 +903,22 @@ mod functional_test { } fn get_sync_client(addr: SocketAddr) -> io::Result - where C: client::sync::ClientExt + where C: sync::client::ClientExt { - C::connect(addr, get_tls_client_options()) + C::connect(addr, get_sync_tls_client_options()) } fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut - where C: client::future::ClientExt + where C: future::client::ClientExt { - C::connect(addr, get_tls_client_options().handle(handle)) + C::connect(addr, get_future_tls_client_options().handle(handle)) } fn start_server_with_sync_client(server: S) - -> io::Result<(SocketAddr, C, server::Shutdown)> - where C: client::sync::ClientExt, S: SyncServiceExt + -> io::Result<(SocketAddr, C, future::server::Shutdown)> + where C: sync::client::ClientExt, S: SyncServiceExt { - let options = get_tls_server_options(); + let options = get_sync_tls_server_options(); let (tx, rx) = ::std::sync::mpsc::channel(); ::std::thread::spawn(move || { let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), @@ -906,31 +927,31 @@ mod functional_test { handle.run(); }); let (addr, shutdown) = rx.recv().unwrap(); - let client = unwrap!(C::connect(addr, get_tls_client_options())); + let client = unwrap!(C::connect(addr, get_sync_tls_client_options())); Ok((addr, client, shutdown)) } fn start_server_with_async_client(server: S) - -> io::Result<(server::future::Handle, reactor::Core, C)> - where C: client::future::ClientExt, S: FutureServiceExt + -> io::Result<(future::server::Handle, reactor::Core, C)> + where C: future::client::ClientExt, S: FutureServiceExt { let mut reactor = reactor::Core::new()?; - let server_options = get_tls_server_options(); + let server_options = get_future_tls_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server_options)?; reactor.handle().spawn(server); - let client_options = get_tls_client_options().handle(reactor.handle()); + let client_options = get_future_tls_client_options().handle(reactor.handle()); let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); Ok((handle, reactor, client)) } fn return_server(server: S) - -> io::Result<(server::future::Handle, reactor::Core, Listen)> + -> io::Result<(future::server::Handle, reactor::Core, Listen)> where S: FutureServiceExt { let reactor = reactor::Core::new()?; - let server_options = get_tls_server_options(); + let server_options = get_future_tls_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server_options)?; @@ -938,45 +959,53 @@ mod functional_test { } fn start_err_server_with_async_client(server: S) - -> io::Result<(server::future::Handle, reactor::Core, C)> - where C: client::future::ClientExt, S: error_service::FutureServiceExt + -> io::Result<(future::server::Handle, reactor::Core, C)> + where C: future::client::ClientExt, S: error_service::FutureServiceExt { let mut reactor = reactor::Core::new()?; - let server_options = get_tls_server_options(); + let server_options = get_future_tls_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), server_options)?; reactor.handle().spawn(server); - let client_options = get_tls_client_options().handle(reactor.handle()); + let client_options = get_future_tls_client_options().handle(reactor.handle()); let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); Ok((handle, reactor, client)) } } else { - fn get_server_options() -> server::Options { - server::Options::default() + fn get_future_server_options() -> future::server::Options { + future::server::Options::default() + } + + fn get_sync_server_options() -> sync::server::Options { + sync::server::Options::default() } - fn get_client_options() -> client::Options { - client::Options::default() + fn get_future_client_options() -> future::client::Options { + future::client::Options::default() + } + + fn get_sync_client_options() -> sync::client::Options { + sync::client::Options::default() } fn get_sync_client(addr: SocketAddr) -> io::Result - where C: client::sync::ClientExt + where C: sync::client::ClientExt { - C::connect(addr, get_client_options()) + C::connect(addr, get_sync_client_options()) } fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut - where C: client::future::ClientExt + where C: future::client::ClientExt { - C::connect(addr, get_client_options().handle(handle)) + C::connect(addr, get_future_client_options().handle(handle)) } fn start_server_with_sync_client(server: S) - -> io::Result<(SocketAddr, C, server::Shutdown)> - where C: client::sync::ClientExt, S: SyncServiceExt + -> io::Result<(SocketAddr, C, future::server::Shutdown)> + where C: sync::client::ClientExt, S: SyncServiceExt { - let options = get_server_options(); + let options = get_sync_server_options(); let (tx, rx) = ::std::sync::mpsc::channel(); ::std::thread::spawn(move || { let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), options)); @@ -989,25 +1018,26 @@ mod functional_test { } fn start_server_with_async_client(server: S) - -> io::Result<(server::future::Handle, reactor::Core, C)> - where C: client::future::ClientExt, S: FutureServiceExt + -> io::Result<(future::server::Handle, reactor::Core, C)> + where C: future::client::ClientExt, S: FutureServiceExt { let mut reactor = reactor::Core::new()?; - let options = get_server_options(); + let options = get_future_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), options)?; reactor.handle().spawn(server); - let client = unwrap!(reactor.run(C::connect(handle.addr(), get_client_options()))); + let client = unwrap!(reactor.run(C::connect(handle.addr(), + get_future_client_options()))); Ok((handle, reactor, client)) } fn return_server(server: S) - -> io::Result<(server::future::Handle, reactor::Core, Listen)> + -> io::Result<(future::server::Handle, reactor::Core, Listen)> where S: FutureServiceExt { let reactor = reactor::Core::new()?; - let options = get_server_options(); + let options = get_future_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), options)?; @@ -1015,25 +1045,25 @@ mod functional_test { } fn start_err_server_with_async_client(server: S) - -> io::Result<(server::future::Handle, reactor::Core, C)> - where C: client::future::ClientExt, S: error_service::FutureServiceExt + -> io::Result<(future::server::Handle, reactor::Core, C)> + where C: future::client::ClientExt, S: error_service::FutureServiceExt { let mut reactor = reactor::Core::new()?; - let options = get_server_options(); + let options = get_future_server_options(); let (handle, server) = server.listen("localhost:0".first_socket_addr(), &reactor.handle(), options)?; reactor.handle().spawn(server); - let client = C::connect(handle.addr(), get_client_options()); + let client = C::connect(handle.addr(), get_future_client_options()); let client = unwrap!(reactor.run(client)); Ok((handle, reactor, client)) } } } - - mod sync { - use super::{SyncClient, SyncService, get_sync_client, env_logger, start_server_with_sync_client}; + mod sync_tests { + use super::{SyncClient, SyncService, get_sync_client, env_logger, + start_server_with_sync_client}; use util::Never; #[derive(Clone, Copy)] @@ -1052,18 +1082,18 @@ mod functional_test { fn simple() { let _ = env_logger::init(); let (_, client, _) = unwrap!(start_server_with_sync_client::(Server)); + Server>(Server)); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); } #[test] fn shutdown() { - use futures::Future; + use futures::{Async, Future}; let _ = env_logger::init(); - let (addr, client, shutdown) = - unwrap!(start_server_with_sync_client::(Server)); + let (addr, client, shutdown) = unwrap!(start_server_with_sync_client::(Server)); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); @@ -1082,7 +1112,12 @@ mod functional_test { tx2.send(add).unwrap(); }); rx.recv().unwrap(); + let mut shutdown1 = shutdown.shutdown(); shutdown.shutdown().wait().unwrap(); + // Assert shutdown2 blocks until shutdown is complete. + if let Async::NotReady = shutdown1.poll().unwrap() { + panic!("Shutdown should have completed"); + } // Existing clients are served assert_eq!(5, rx2.recv().unwrap()); @@ -1093,8 +1128,8 @@ mod functional_test { #[test] fn no_shutdown() { let _ = env_logger::init(); - let (addr, client, shutdown) = - unwrap!(start_server_with_sync_client::(Server)); + let (addr, client, shutdown) = unwrap!(start_server_with_sync_client::(Server)); assert_eq!(3, client.add(1, 2).unwrap()); assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); @@ -1113,13 +1148,48 @@ mod functional_test { unwrap!(start_server_with_sync_client::(Server)); match client.foo().err().expect("failed unwrap") { - ::Error::ServerSerialize(_) => {} // good - bad => panic!("Expected Error::ServerSerialize but got {}", bad), + ::Error::RequestDeserialize(_) => {} // good + bad => panic!("Expected Error::RequestDeserialize but got {}", bad), } } } - mod future { + mod bad_serialize { + use sync::{client, server}; + use sync::client::ClientExt; + use serde::{Serialize, Serializer}; + use serde::ser::SerializeSeq; + + #[derive(Deserialize)] + pub struct Bad; + + impl Serialize for Bad { + fn serialize(&self, serializer: S) -> Result + where S: Serializer + { + serializer.serialize_seq(None)?.end() + } + } + + service! { + rpc bad(bad: Bad) | (); + } + + impl SyncService for () { + fn bad(&self, _: Bad) -> Result<(), ()> { + Ok(()) + } + } + + #[test] + fn bad_serialize() { + let handle = ().listen("localhost:0", server::Options::default()).unwrap(); + let client = SyncClient::connect(handle.addr(), client::Options::default()).unwrap(); + client.bad(Bad).err().unwrap(); + } + } + + mod future_tests { use super::{FutureClient, FutureService, env_logger, get_future_client, return_server, start_server_with_async_client}; use futures::{Finished, finished}; @@ -1198,15 +1268,15 @@ mod functional_test { unwrap!(start_server_with_async_client::(Server)); match reactor.run(client.foo()).err().unwrap() { - ::Error::ServerSerialize(_) => {} // good - bad => panic!(r#"Expected Error::ServerSerialize but got "{}""#, bad), + ::Error::RequestDeserialize(_) => {} // good + bad => panic!(r#"Expected Error::RequestDeserialize but got "{}""#, bad), } } #[test] fn reuse_addr() { use util::FirstSocketAddr; - use server; + use future::server; use super::FutureServiceExt; let _ = env_logger::init(); @@ -1221,8 +1291,8 @@ mod functional_test { #[test] fn drop_client() { - use {client, server}; - use client::future::ClientExt; + use future::{client, server}; + use future::client::ClientExt; use util::FirstSocketAddr; use super::{FutureClient, FutureServiceExt}; @@ -1249,9 +1319,9 @@ mod functional_test { #[cfg(feature = "tls")] #[test] fn tcp_and_tls() { - use {client, server}; + use future::{client, server}; use util::FirstSocketAddr; - use client::future::ClientExt; + use future::client::ClientExt; use super::FutureServiceExt; let _ = env_logger::init(); diff --git a/src/sync/client.rs b/src/sync/client.rs new file mode 100644 index 00000000..6aa142cf --- /dev/null +++ b/src/sync/client.rs @@ -0,0 +1,192 @@ + +use future::client::{Client as FutureClient, ClientExt as FutureClientExt, + Options as FutureOptions}; +/// Exposes a trait for connecting synchronously to servers. +use futures::{Future, Stream}; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::io; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::sync::mpsc; +use std::thread; +use tokio_core::reactor; +use tokio_proto::util::client_proxy::{ClientProxy, Receiver, pair}; +use tokio_service::Service; +use util::FirstSocketAddr; +#[cfg(feature = "tls")] +use tls::client::Context; + +#[doc(hidden)] +pub struct Client { + proxy: ClientProxy>, +} + +impl Clone for Client { + fn clone(&self) -> Self { + Client { proxy: self.proxy.clone() } + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "Client {{ .. }}") + } +} + +impl Client + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static +{ + /// Drives an RPC call for the given request. + pub fn call(&self, request: Req) -> Result> { + self.proxy.call(request).wait() + } + +} + +/// Additional options to configure how the client connects and operates. +#[derive(Default)] +pub struct Options { + #[cfg(feature = "tls")] + tls_ctx: Option, +} + +impl Options { + /// Connect using the given `Context` + #[cfg(feature = "tls")] + pub fn tls(mut self, ctx: Context) -> Self { + self.tls_ctx = Some(ctx); + self + } +} + +impl Into for (reactor::Handle, Options) { + #[cfg(feature = "tls")] + fn into(self) -> FutureOptions { + let (handle, options) = self; + let mut opts = FutureOptions::default().handle(handle); + if let Some(tls_ctx) = options.tls_ctx { + opts = opts.tls(tls_ctx); + } + opts + } + + #[cfg(not(feature = "tls"))] + fn into(self) -> FutureOptions { + let (handle, _) = self; + FutureOptions::default().handle(handle) + } +} + +/// Extension methods for Clients. +pub trait ClientExt: Sized { + /// Connects to a server located at the given address. + fn connect(addr: A, options: Options) -> io::Result where A: ToSocketAddrs; +} + +impl ClientExt for Client + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static +{ + fn connect(addr: A, options: Options) -> io::Result + where A: ToSocketAddrs + { + let addr = addr.try_first_socket_addr()?; + let (connect_tx, connect_rx) = mpsc::channel(); + thread::spawn(move || { + match RequestHandler::connect(addr, options) { + Ok((proxy, mut handler)) => { + connect_tx.send(Ok(proxy)).unwrap(); + handler.handle_requests(); + } + Err(e) => connect_tx.send(Err(e)).unwrap(), + } + }); + Ok(connect_rx.recv().unwrap()?) + } +} + +/// Forwards incoming requests of type `Req` +/// with expected response `Result>` +/// to service `S`. +struct RequestHandler { + reactor: reactor::Core, + client: S, + requests: Receiver>, +} + +impl RequestHandler> + where Req: Serialize + Sync + Send + 'static, + Resp: Deserialize + Sync + Send + 'static, + E: Deserialize + Sync + Send + 'static +{ + /// Creates a new `RequestHandler` by connecting a `FutureClient` to the given address + /// using the given options. + fn connect(addr: SocketAddr, options: Options) + -> io::Result<(Client, Self)> + { + let mut reactor = reactor::Core::new()?; + let options = (reactor.handle(), options).into(); + let client = reactor.run(FutureClient::connect(addr, options))?; + let (proxy, requests) = pair(); + Ok((Client { proxy }, RequestHandler { reactor, client, requests })) + } +} + +impl RequestHandler + where Req: Serialize + 'static, + Resp: Deserialize + 'static, + E: Deserialize + 'static, + S: Service>, + S::Future: 'static, +{ + fn handle_requests(&mut self) { + let RequestHandler { ref mut reactor, ref mut requests, ref mut client } = *self; + let handle = reactor.handle(); + let requests = requests + .map(|result| { + match result { + Ok(req) => req, + // The ClientProxy never sends Err currently + Err(e) => panic!("Unimplemented error handling in RequestHandler: {}", e), + } + }) + .for_each(|(request, response_tx)| { + let request = client.call(request) + .then(move |response| { + response_tx.complete(response); + Ok(()) + }); + handle.spawn(request); + Ok(()) + }); + reactor.run(requests).unwrap(); + } +} + +#[test] +fn handle_requests() { + use futures::future; + + struct Client; + impl Service for Client { + type Request = i32; + type Response = i32; + type Error = ::Error<()>; + type Future = future::FutureResult>; + + fn call(&self, req: i32) -> Self::Future { + future::ok(req) + } + } + + let (request, requests) = ::futures::sync::mpsc::unbounded(); + let reactor = reactor::Core::new().unwrap(); + let client = Client; + let mut request_handler = RequestHandler { reactor, client, requests }; + // Test that `handle_requests` returns when all request senders are dropped. + drop(request); + request_handler.handle_requests(); +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 00000000..79011a35 --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,4 @@ +/// Provides the base client stubs used by the service macro. +pub mod client; +/// Provides the base server boilerplate used by service implementations. +pub mod server; diff --git a/src/sync/server.rs b/src/sync/server.rs new file mode 100644 index 00000000..270c54fb --- /dev/null +++ b/src/sync/server.rs @@ -0,0 +1,77 @@ +use {bincode, future}; +use future::server::{Response, Shutdown}; +use futures::Future; +use serde::{Deserialize, Serialize}; +use std::io; +use std::net::SocketAddr; +use tokio_core::reactor; +use tokio_service::NewService; +#[cfg(feature = "tls")] +use native_tls_inner::TlsAcceptor; + +/// Additional options to configure how the server operates. +#[derive(Default)] +pub struct Options { + opts: future::server::Options, +} + +impl Options { + /// Set the `TlsAcceptor` + #[cfg(feature = "tls")] + pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { + self.opts = self.opts.tls(tls_acceptor); + self + } +} + +/// A handle to a bound server. Must be run to start serving requests. +#[must_use = "A server does nothing until `run` is called."] +pub struct Handle { + reactor: reactor::Core, + handle: future::server::Handle, + server: Box>, +} + +impl Handle { + #[doc(hidden)] + pub fn listen(new_service: S, + addr: SocketAddr, + options: Options) + -> io::Result + where S: NewService, + Response = Response, + Error = io::Error> + 'static, + Req: Deserialize + 'static, + Resp: Serialize + 'static, + E: Serialize + 'static + { + let reactor = reactor::Core::new()?; + let (handle, server) = + future::server::Handle::listen(new_service, addr, &reactor.handle(), options.opts)?; + let server = Box::new(server); + Ok(Handle { + reactor: reactor, + handle: handle, + server: server, + }) + } + + /// Runs the server on the current thread, blocking indefinitely. + pub fn run(mut self) { + trace!("Running..."); + match self.reactor.run(self.server) { + Ok(()) => debug!("Server successfully shutdown."), + Err(()) => debug!("Server shutdown due to error."), + } + } + + /// Returns a hook for shutting down the server. + pub fn shutdown(&self) -> Shutdown { + self.handle.shutdown().clone() + } + + /// The socket address the server is bound to. + pub fn addr(&self) -> SocketAddr { + self.handle.addr() + } +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 00000000..063ed808 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,39 @@ +/// TLS-specific functionality for clients. +pub mod client { + use native_tls::{Error, TlsConnector}; + + /// TLS context for client + pub struct Context { + /// Domain to connect to + pub domain: String, + /// TLS connector + pub tls_connector: TlsConnector, + } + + impl Context { + /// Try to construct a new `Context`. + /// + /// The provided domain will be used for both + /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname + /// validation. + pub fn new>(domain: S) -> Result { + Ok(Context { + domain: domain.into(), + tls_connector: TlsConnector::builder()?.build()?, + }) + } + + /// Construct a new `Context` using the provided domain and `TlsConnector` + /// + /// The domain will be used for both + /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname + /// validation. + pub fn from_connector>(domain: S, tls_connector: TlsConnector) -> Self { + Context { + domain: domain.into(), + tls_connector: tls_connector, + } + } + } +} +