Skip to content

Commit

Permalink
Refactor server spawning to be synchronous and infallible (#2097)
Browse files Browse the repository at this point in the history
* Make simple server spawn methods synchronous

Perform all IO operations in the spawned task.

* Make gRPC server spawn synchronous

Perform all IO operations in the spawned task.

* Make gRPC server spawn infallible

Return an error only from the spawned task.

* Start metrics outside of `async` block

There's no need to delay starting the metrics.

* Make `spawn_{grpc,simple}` infallible

There aren't any failure scenarios for them.

* Remove unnecessary bindings

The fields can be cloned directly in the `async` block, just as is already
done with the `server_config.internal_network` value.

* Move server `spawn` calls out of the `async` block

There's no need to keep them in the block, now that they are
synchronous.

* Refactor to use `FuturesUnordered`

Make it explicit that the order of the futures doesn't matter.
  • Loading branch information
jvff authored Jun 4, 2024
1 parent 805c4fc commit 7a6da04
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 103 deletions.
38 changes: 21 additions & 17 deletions linera-rpc/src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ pub struct GrpcServer<S> {

pub struct GrpcServerHandle {
_complete: Sender<()>,
handle: JoinHandle<Result<(), tonic::transport::Error>>,
handle: JoinHandle<Result<(), GrpcError>>,
}

impl GrpcServerHandle {
pub async fn join(self) -> Result<(), GrpcError> {
Ok(self.handle.await??)
self.handle.await?
}
}

Expand Down Expand Up @@ -177,22 +177,20 @@ where
ViewError: From<S::ContextError>,
{
#[allow(clippy::too_many_arguments)]
pub async fn spawn(
pub fn spawn(
host: String,
port: u16,
state: WorkerState<S>,
shard_id: ShardId,
internal_network: ValidatorInternalNetworkConfig,
cross_chain_config: CrossChainConfig,
notification_config: NotificationConfig,
) -> Result<GrpcServerHandle, GrpcError> {
) -> GrpcServerHandle {
info!(
"spawning gRPC server on {}:{} for shard {}",
host, port, shard_id
);

let server_address = SocketAddr::from((IpAddr::from_str(&host)?, port));

let (cross_chain_sender, cross_chain_receiver) =
mpsc::channel(cross_chain_config.queue_size);

Expand Down Expand Up @@ -232,9 +230,6 @@ where
let (complete, receiver) = futures::channel::oneshot::channel();

let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
health_reporter
.set_serving::<ValidatorWorkerServer<Self>>()
.await;

let grpc_server = GrpcServer {
state,
Expand All @@ -248,11 +243,17 @@ where
.max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE)
.max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE);

let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(crate::FILE_DESCRIPTOR_SET)
.build()?;
let handle = tokio::spawn(async move {
let server_address = SocketAddr::from((IpAddr::from_str(&host)?, port));

let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(crate::FILE_DESCRIPTOR_SET)
.build()?;

health_reporter
.set_serving::<ValidatorWorkerServer<Self>>()
.await;

let handle = tokio::spawn(
tonic::transport::Server::builder()
.layer(
ServiceBuilder::new()
Expand All @@ -262,13 +263,16 @@ where
.add_service(health_service)
.add_service(reflection_service)
.add_service(worker_node)
.serve_with_shutdown(server_address, receiver.map(|_| ())),
);
.serve_with_shutdown(server_address, receiver.map(|_| ()))
.await?;

Ok(())
});

Ok(GrpcServerHandle {
GrpcServerHandle {
_complete: complete,
handle,
})
}
}

/// Continuously waits for receiver to receive a notification which is then sent to
Expand Down
6 changes: 3 additions & 3 deletions linera-rpc/src/simple/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::{io, time::Duration};
use std::time::Duration;

use async_trait::async_trait;
use futures::{channel::mpsc, stream::StreamExt};
Expand Down Expand Up @@ -134,7 +134,7 @@ where
}
}

pub async fn spawn(self) -> Result<ServerHandle, io::Error> {
pub fn spawn(self) -> ServerHandle {
info!(
"Listening to {:?} traffic on {}:{}",
self.network.protocol, self.host, self.port
Expand All @@ -161,7 +161,7 @@ where
cross_chain_sender,
};
// Launch server for the appropriate protocol.
protocol.spawn_server(address, state).await
protocol.spawn_server(address, state)
}
}

Expand Down
24 changes: 10 additions & 14 deletions linera-rpc/src/simple/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,19 @@ impl TransportProtocol {
}

/// Runs a server for this protocol and the given message handler.
pub async fn spawn_server<S>(
pub fn spawn_server<S>(
self,
address: impl ToSocketAddrs,
address: impl ToSocketAddrs + Send + 'static,
state: S,
) -> Result<ServerHandle, std::io::Error>
) -> ServerHandle
where
S: MessageHandler + Send + 'static,
{
let handle = match self {
Self::Udp => {
let socket = UdpSocket::bind(address).await?;
tokio::spawn(Self::run_udp_server(socket, state))
}
Self::Tcp => {
let listener = TcpListener::bind(address).await?;
tokio::spawn(Self::run_tcp_server(listener, state))
}
Self::Udp => tokio::spawn(Self::run_udp_server(address, state)),
Self::Tcp => tokio::spawn(Self::run_tcp_server(address, state)),
};
Ok(ServerHandle { handle })
ServerHandle { handle }
}
}

Expand Down Expand Up @@ -194,10 +188,11 @@ impl ConnectionPool for UdpConnectionPool {

// Server implementation for UDP.
impl TransportProtocol {
async fn run_udp_server<S>(socket: UdpSocket, state: S) -> Result<(), std::io::Error>
async fn run_udp_server<S>(address: impl ToSocketAddrs, state: S) -> Result<(), std::io::Error>
where
S: MessageHandler + Send + 'static,
{
let socket = UdpSocket::bind(address).await?;
let (udp_sink, mut udp_stream) = UdpFramed::new(socket, Codec).split();
let udp_sink = Arc::new(Mutex::new(udp_sink));
// Track the latest tasks for a given peer. This is used to return answers in the
Expand Down Expand Up @@ -290,10 +285,11 @@ impl ConnectionPool for TcpConnectionPool {

// Server implementation for TCP.
impl TransportProtocol {
async fn run_tcp_server<S>(listener: TcpListener, state: S) -> Result<(), std::io::Error>
async fn run_tcp_server<S>(address: impl ToSocketAddrs, state: S) -> Result<(), std::io::Error>
where
S: MessageHandler + Send + 'static,
{
let listener = TcpListener::bind(address).await?;
let mut accept_stream = stream::try_unfold(listener, |listener| async move {
let (socket, _) = listener.accept().await?;
Ok::<_, io::Error>(Some((socket, listener)))
Expand Down
3 changes: 1 addition & 2 deletions linera-service/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ where

self.public_config
.protocol
.spawn_server(&address, self)
.await?
.spawn_server(address, self)
.join()
.await?;
Ok(())
Expand Down
125 changes: 58 additions & 67 deletions linera-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{path::PathBuf, time::Duration};

use anyhow::bail;
use async_trait::async_trait;
use futures::future::join_all;
use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
use linera_base::crypto::{CryptoRng, KeyPair};
use linera_core::worker::WorkerState;
use linera_execution::{committee::ValidatorName, WasmRuntime, WithWasmDefault};
Expand Down Expand Up @@ -67,95 +67,86 @@ impl ServerContext {
listen_address: &str,
states: Vec<(WorkerState<S>, ShardId, ShardConfig)>,
protocol: simple::TransportProtocol,
) -> Result<(), anyhow::Error>
where
) where
S: Storage + Clone + Send + Sync + 'static,
ViewError: From<S::ContextError>,
{
let handles = FuturesUnordered::new();

let internal_network = self
.server_config
.internal_network
.clone_with_protocol(protocol);

let mut handles = Vec::new();
for (state, shard_id, shard) in states {
let internal_network = internal_network.clone();
let cross_chain_config = self.cross_chain_config.clone();
handles.push(async move {
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(listen_address, port);
}
let server = simple::Server::new(
internal_network,
listen_address.to_string(),
shard.port,
state,
shard_id,
cross_chain_config,
);
let spawned_server = match server.spawn().await {
Ok(server) => server,
Err(err) => {
error!("Failed to start server: {}", err);
return;
}
};
if let Err(err) = spawned_server.join().await {
error!("Server ended with an error: {}", err);
}
});
}
let listen_address = listen_address.to_owned();

join_all(handles).await;
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(&listen_address, port);
}

Ok(())
let server_handle = simple::Server::new(
internal_network,
listen_address,
shard.port,
state,
shard_id,
cross_chain_config,
)
.spawn();

handles.push(
server_handle
.join()
.inspect_err(move |error| {
error!("Error running server for shard {shard_id}: {error:?}")
})
.map(|_| ()),
);
}

handles.collect::<()>().await;
}

async fn spawn_grpc<S>(
&self,
listen_address: &str,
states: Vec<(WorkerState<S>, ShardId, ShardConfig)>,
) -> Result<(), anyhow::Error>
where
) where
S: Storage + Clone + Send + Sync + 'static,
ViewError: From<S::ContextError>,
{
let mut handles = Vec::new();
let handles = FuturesUnordered::new();
for (state, shard_id, shard) in states {
let cross_chain_config = self.cross_chain_config.clone();
let notification_config = self.notification_config.clone();
handles.push(async move {
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(listen_address, port);
}
let spawned_server = match grpc::GrpcServer::spawn(
listen_address.to_string(),
shard.port,
state,
shard_id,
self.server_config.internal_network.clone(),
cross_chain_config,
notification_config,
)
.await
{
Ok(spawned_server) => spawned_server,
Err(err) => {
error!("Failed to start server: {:?}", err);
return;
}
};
if let Err(err) = spawned_server.join().await {
error!("Server ended with an error: {}", err);
}
});
}
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(listen_address, port);
}

join_all(handles).await;
let server_handle = grpc::GrpcServer::spawn(
listen_address.to_string(),
shard.port,
state,
shard_id,
self.server_config.internal_network.clone(),
self.cross_chain_config.clone(),
self.notification_config.clone(),
);

handles.push(
server_handle
.join()
.inspect_err(move |error| {
error!("Error running server for shard {shard_id}: {error:?}")
})
.map(|_| ()),
);
}

Ok(())
handles.collect::<()>().await;
}

#[cfg(with_metrics)]
Expand Down Expand Up @@ -197,10 +188,10 @@ impl Runnable for ServerContext {

match self.server_config.internal_network.protocol {
NetworkProtocol::Simple(protocol) => {
self.spawn_simple(&listen_address, states, protocol).await?
self.spawn_simple(&listen_address, states, protocol).await
}
NetworkProtocol::Grpc(tls_config) => match tls_config {
TlsConfig::ClearText => self.spawn_grpc(&listen_address, states).await?,
TlsConfig::ClearText => self.spawn_grpc(&listen_address, states).await,
TlsConfig::Tls => bail!("TLS not supported between proxy and shards."),
},
};
Expand Down

0 comments on commit 7a6da04

Please sign in to comment.