Skip to content

Commit

Permalink
quit client if server close (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
sword-jin authored Jul 20, 2024
1 parent 3ee9463 commit ac39eff
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 37 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ run-client: build
.PHONY: e2e
e2e: build
./tests/e2e/test_close_server_gracefully.sh
./tests/e2e/test_client_close_when_server_close.sh
./tests/e2e/test_basic_tcp.sh
./tests/e2e/test_tcp_local_server_not_start.sh
./tests/e2e/test_tcp_with_tunnel_http_server.sh
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Castle

[![CI Test](https://github.com/openosaka/castled/actions/workflows/ci.yaml/badge.svg)](https://github.com/openosaka/castled/actions/workflows/ci.yaml)
[![License](https://img.shields.io/crates/l/castled)](https://github.com/openosaka/castled/blob/main/LICENSE)

Castle is a simple tunnel based on GRPC that allows you to expose your local services to the internet,
but it's **mainly designed for 🌟testing and ✨development purposes**.
Expand Down
4 changes: 2 additions & 2 deletions examples/crawler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async fn main() -> anyhow::Result<()> {
false,
0,
),
shutdown.wait_shutdown_triggered(),
shutdown.clone(),
)
.await?;
// ### call the proxy
Expand All @@ -105,7 +105,7 @@ async fn main() -> anyhow::Result<()> {
let response2_text = response2.text().await?;
assert_eq!(response1_text, response2_text);

shutdown.trigger_shutdown(()).unwrap();
shutdown.trigger_shutdown(0).unwrap();
Ok(())
}

Expand Down
13 changes: 5 additions & 8 deletions src/bin/castle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async fn main() -> anyhow::Result<()> {

let client = Client::new(args.server_addr);
let tunnel;
let shutdown: ShutdownManager<()> = ShutdownManager::new();
let shutdown: ShutdownManager<i8> = ShutdownManager::new();
let wait_complete = shutdown.wait_shutdown_complete();

match args.command {
Expand Down Expand Up @@ -130,9 +130,7 @@ async fn main() -> anyhow::Result<()> {
}
}

let entrypoint = client
.start_tunnel(tunnel, shutdown.wait_shutdown_triggered())
.await?;
let entrypoint = client.start_tunnel(tunnel, shutdown.clone()).await?;

info!("Entrypoint: {:?}", entrypoint);

Expand All @@ -142,12 +140,11 @@ async fn main() -> anyhow::Result<()> {
panic!("Failed to listen for the ctrl-c signal: {:?}", e);
}
info!("Received ctrl-c signal. Shutting down...");
shutdown.trigger_shutdown(()).unwrap();
shutdown.trigger_shutdown(0).unwrap();
});

wait_complete.await;

Ok(())
let code = wait_complete.await;
std::process::exit(code as i32)
}

async fn parse_socket_addr(local_addr: &str, port: u16) -> anyhow::Result<SocketAddr> {
Expand Down
22 changes: 13 additions & 9 deletions src/client/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
use async_shutdown::ShutdownSignal;
use async_shutdown::{ShutdownManager, ShutdownSignal};
use std::net::SocketAddr;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tonic::{transport::Channel, Response, Status, Streaming};
Expand Down Expand Up @@ -54,27 +54,29 @@ impl Client {
/// let client = Client::new("127.0.0.1:6100".parse().unwrap());
/// let tunnel = new_tcp_tunnel(String::from("my-tunnel"), SocketAddr::from(([127, 0, 0, 1], 8971)), 8080);
/// let shutdown = ShutdownManager::new();
/// let entrypoint = client.start_tunnel(tunnel, shutdown.wait_shutdown_triggered()).await.unwrap();
/// let entrypoint = client.start_tunnel(tunnel, shutdown.clone()).await.unwrap();
/// println!("entrypoint: {:?}", entrypoint);
/// shutdown.wait_shutdown_complete().await;
/// }
/// ```
pub async fn start_tunnel(
self,
tunnel: Tunnel,
shutdown: ShutdownSignal<()>,
shutdown: ShutdownManager<i8>,
) -> Result<Vec<String>> {
let (entrypoint_tx, entrypoint_rx) = oneshot::channel();

tokio::spawn(async move {
let run_tunnel = self.handle_tunnel(
shutdown.clone(),
shutdown.wait_shutdown_triggered(),
tunnel,
Some(move |entrypoint| {
let _ = entrypoint_tx.send(entrypoint);
}),
);

tokio::select! {
_ = shutdown => {
_ = shutdown.wait_shutdown_triggered() => {
debug!("cancelling tcp tunnel");
}
result = run_tunnel => match result {
Expand All @@ -83,9 +85,11 @@ impl Client {
},
Err(err) => {
error!(err = ?err, "tunnel closed unexpectedly");
return shutdown.trigger_shutdown_token(1);
},
}
}
shutdown.trigger_shutdown_token(0)
});

let entrypoint = entrypoint_rx
Expand All @@ -99,7 +103,7 @@ impl Client {
/// we treat the tunnel has been established successfully after we receive the init command.
async fn wait_until_registered(
&self,
shutdown: ShutdownSignal<()>,
shutdown: ShutdownSignal<i8>,
control_stream: &mut Streaming<Control>,
) -> Result<Vec<String>> {
select! {
Expand Down Expand Up @@ -168,7 +172,7 @@ impl Client {
#[allow(dead_code)]
async fn handle_tunnel(
&self,
shutdown: ShutdownSignal<()>,
shutdown: ShutdownSignal<i8>,
tunnel: Tunnel,
hook: Option<impl FnOnce(Vec<String>) + Send + 'static>,
) -> Result<()> {
Expand Down Expand Up @@ -198,7 +202,7 @@ impl Client {
#[instrument(skip(self, shutdown, rpc_client, register_resp, hook))]
async fn handle_control_stream(
&self,
shutdown: ShutdownSignal<()>,
shutdown: ShutdownSignal<i8>,
rpc_client: TunnelServiceClient<Channel>,
register_resp: tonic::Response<Streaming<Control>>,
local_endpoint: SocketAddr,
Expand Down Expand Up @@ -226,7 +230,7 @@ impl Client {

async fn start_streaming(
&self,
shutdown: ShutdownSignal<()>,
shutdown: ShutdownSignal<i8>,
control_stream: &mut Streaming<Control>,
rpc_client: TunnelServiceClient<Channel>,
local_endpoint: SocketAddr,
Expand Down
7 changes: 2 additions & 5 deletions tests/e2e/test_basic_tcp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ trap cleanup EXIT
# Start the tunnel server
exec ./target/debug/castled &
server_pid=$!

# Give the server some time to start
sleep 1
sleep 0.2

# Start the tunnel client
exec ./target/debug/castle tcp 12348 --remote-port 9992 &
client_pid=$!

# Start the nc TCP server
exec nc -l -p 12348 > actual.txt & # it closes with nc's timeout

sleep 1
sleep 0.2

# quit after 1 second
# TODO(sword): we don't require this timeout,
Expand Down
23 changes: 23 additions & 0 deletions tests/e2e/test_client_close_when_server_close.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
set -x

# Start the tunnel server
exec ./target/debug/castled &
server_pid=$!
sleep 1

exec ./target/debug/castle tcp 12348 --remote-port 9992 &
client_pid=$!

sleep 1

kill -SIGINT $server_pid

# Wait for the client process to exit
wait $client_pid

# Check the exit code should be 1
if [ $? -ne 1 ]; then
echo "Test failed"
exit 1
fi
24 changes: 11 additions & 13 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn client_register_tcp() {
..Default::default()
})
.await;
let close_client = server.cancel.clone().wait_shutdown_triggered();
let close_client = server.cancel.clone();
let remote_port = free_port().unwrap();
let control_addr = server.control_addr().clone();

Expand Down Expand Up @@ -71,7 +71,7 @@ async fn client_register_tcp() {
server.vhttp_port,
);

server.cancel.trigger_shutdown(()).unwrap();
server.cancel.trigger_shutdown(0).unwrap();

let client_exit = tokio::join!(client_handler);
assert!(client_exit.0.is_ok());
Expand All @@ -96,14 +96,14 @@ async fn client_register_and_close_then_register_again() {
SocketAddr::from(([127, 0, 0, 1], 8971)),
remote_port,
),
close_client.wait_shutdown_triggered(),
close_client.clone(),
)
.await;
});

tokio::spawn(async move {
sleep(tokio::time::Duration::from_millis(300)).await;
shutdown.trigger_shutdown(());
shutdown.trigger_shutdown(0);
});

let client_exit = tokio::join!(client_handler);
Expand All @@ -122,14 +122,14 @@ async fn client_register_and_close_then_register_again() {
SocketAddr::from(([127, 0, 0, 1], 8971)), /* no matter */
remote_port,
),
close_client.wait_shutdown_triggered(),
close_client,
)
.await;
});

tokio::spawn(async move {
sleep(tokio::time::Duration::from_millis(300)).await;
shutdown.trigger_shutdown(());
shutdown.trigger_shutdown(0);
});

let client_exit = tokio::join!(client_handler);
Expand Down Expand Up @@ -170,7 +170,7 @@ async fn register_http_tunnel_with_subdomain() {
false,
0,
),
close_client.wait_shutdown_triggered(),
close_client,
)
.await;
wait_client_register.send(()).unwrap();
Expand All @@ -195,7 +195,7 @@ async fn register_http_tunnel_with_subdomain() {
);
assert_eq!(response.text().await.unwrap(), mock_body);

server.cancel.trigger_shutdown(()).unwrap();
server.cancel.trigger_shutdown(0).unwrap();

let client_exit = tokio::join!(client_handler);
assert!(client_exit.0.is_ok());
Expand Down Expand Up @@ -443,17 +443,15 @@ async fn test_assigned_entrypoint() {

let client_handler = tokio::spawn(async move {
let client = Client::new(control_addr);
let entrypoint = client
.start_tunnel(tunnel.tunnel, close_client.wait_shutdown_triggered())
.await;
let entrypoint = client.start_tunnel(tunnel.tunnel, close_client).await;
assert!(entrypoint.is_ok());
let entrypoint = entrypoint.unwrap();
assert_eq!(entrypoint, tunnel.expected);
});

tokio::spawn(async move {
sleep(tokio::time::Duration::from_millis(100)).await;
shutdown.trigger_shutdown(()).unwrap();
shutdown.trigger_shutdown(0).unwrap();
});

let _ = tokio::join!(client_handler);
Expand All @@ -464,7 +462,7 @@ async fn test_assigned_entrypoint() {
struct TestServer {
control_port: u16,
vhttp_port: u16,
cancel: ShutdownManager<()>,
cancel: ShutdownManager<i8>,
}

impl TestServer {
Expand Down

0 comments on commit ac39eff

Please sign in to comment.