diff --git a/Makefile b/Makefile index fcaccfb..99fea6e 100644 --- a/Makefile +++ b/Makefile @@ -35,6 +35,7 @@ run-client: build .PHONY: e2e e2e: build ./tests/e2e/test_close_server_gracefully.sh + ./tests/e2e/test_client_close_when_server_close.sh ./tests/e2e/test_basic_tcp.sh ./tests/e2e/test_tcp_local_server_not_start.sh ./tests/e2e/test_tcp_with_tunnel_http_server.sh diff --git a/README.md b/README.md index 9481b42..c5494fd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Castle [![CI Test](https://github.com/openosaka/castled/actions/workflows/ci.yaml/badge.svg)](https://github.com/openosaka/castled/actions/workflows/ci.yaml) +[![License](https://img.shields.io/crates/l/castled)](https://github.com/openosaka/castled/blob/main/LICENSE) Castle is a simple tunnel based on GRPC that allows you to expose your local services to the internet, but it's **mainly designed for 🌟testing and ✨development purposes**. diff --git a/examples/crawler.rs b/examples/crawler.rs index 1db37d1..1152b52 100644 --- a/examples/crawler.rs +++ b/examples/crawler.rs @@ -87,7 +87,7 @@ async fn main() -> anyhow::Result<()> { false, 0, ), - shutdown.wait_shutdown_triggered(), + shutdown.clone(), ) .await?; // ### call the proxy @@ -105,7 +105,7 @@ async fn main() -> anyhow::Result<()> { let response2_text = response2.text().await?; assert_eq!(response1_text, response2_text); - shutdown.trigger_shutdown(()).unwrap(); + shutdown.trigger_shutdown(0).unwrap(); Ok(()) } diff --git a/src/bin/castle.rs b/src/bin/castle.rs index 44b80ca..d3cc1fa 100644 --- a/src/bin/castle.rs +++ b/src/bin/castle.rs @@ -82,7 +82,7 @@ async fn main() -> anyhow::Result<()> { let client = Client::new(args.server_addr); let tunnel; - let shutdown: ShutdownManager<()> = ShutdownManager::new(); + let shutdown: ShutdownManager = ShutdownManager::new(); let wait_complete = shutdown.wait_shutdown_complete(); match args.command { @@ -130,9 +130,7 @@ async fn main() -> anyhow::Result<()> { } } - let entrypoint = client - .start_tunnel(tunnel, shutdown.wait_shutdown_triggered()) - .await?; + let entrypoint = client.start_tunnel(tunnel, shutdown.clone()).await?; info!("Entrypoint: {:?}", entrypoint); @@ -142,12 +140,11 @@ async fn main() -> anyhow::Result<()> { panic!("Failed to listen for the ctrl-c signal: {:?}", e); } info!("Received ctrl-c signal. Shutting down..."); - shutdown.trigger_shutdown(()).unwrap(); + shutdown.trigger_shutdown(0).unwrap(); }); - wait_complete.await; - - Ok(()) + let code = wait_complete.await; + std::process::exit(code as i32) } async fn parse_socket_addr(local_addr: &str, port: u16) -> anyhow::Result { diff --git a/src/client/client.rs b/src/client/client.rs index cd50bfc..a2bac82 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use async_shutdown::ShutdownSignal; +use async_shutdown::{ShutdownManager, ShutdownSignal}; use std::net::SocketAddr; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{transport::Channel, Response, Status, Streaming}; @@ -54,27 +54,29 @@ impl Client { /// let client = Client::new("127.0.0.1:6100".parse().unwrap()); /// let tunnel = new_tcp_tunnel(String::from("my-tunnel"), SocketAddr::from(([127, 0, 0, 1], 8971)), 8080); /// let shutdown = ShutdownManager::new(); - /// let entrypoint = client.start_tunnel(tunnel, shutdown.wait_shutdown_triggered()).await.unwrap(); + /// let entrypoint = client.start_tunnel(tunnel, shutdown.clone()).await.unwrap(); /// println!("entrypoint: {:?}", entrypoint); + /// shutdown.wait_shutdown_complete().await; /// } /// ``` pub async fn start_tunnel( self, tunnel: Tunnel, - shutdown: ShutdownSignal<()>, + shutdown: ShutdownManager, ) -> Result> { let (entrypoint_tx, entrypoint_rx) = oneshot::channel(); tokio::spawn(async move { let run_tunnel = self.handle_tunnel( - shutdown.clone(), + shutdown.wait_shutdown_triggered(), tunnel, Some(move |entrypoint| { let _ = entrypoint_tx.send(entrypoint); }), ); + tokio::select! { - _ = shutdown => { + _ = shutdown.wait_shutdown_triggered() => { debug!("cancelling tcp tunnel"); } result = run_tunnel => match result { @@ -83,9 +85,11 @@ impl Client { }, Err(err) => { error!(err = ?err, "tunnel closed unexpectedly"); + return shutdown.trigger_shutdown_token(1); }, } } + shutdown.trigger_shutdown_token(0) }); let entrypoint = entrypoint_rx @@ -99,7 +103,7 @@ impl Client { /// we treat the tunnel has been established successfully after we receive the init command. async fn wait_until_registered( &self, - shutdown: ShutdownSignal<()>, + shutdown: ShutdownSignal, control_stream: &mut Streaming, ) -> Result> { select! { @@ -168,7 +172,7 @@ impl Client { #[allow(dead_code)] async fn handle_tunnel( &self, - shutdown: ShutdownSignal<()>, + shutdown: ShutdownSignal, tunnel: Tunnel, hook: Option) + Send + 'static>, ) -> Result<()> { @@ -198,7 +202,7 @@ impl Client { #[instrument(skip(self, shutdown, rpc_client, register_resp, hook))] async fn handle_control_stream( &self, - shutdown: ShutdownSignal<()>, + shutdown: ShutdownSignal, rpc_client: TunnelServiceClient, register_resp: tonic::Response>, local_endpoint: SocketAddr, @@ -226,7 +230,7 @@ impl Client { async fn start_streaming( &self, - shutdown: ShutdownSignal<()>, + shutdown: ShutdownSignal, control_stream: &mut Streaming, rpc_client: TunnelServiceClient, local_endpoint: SocketAddr, diff --git a/tests/e2e/test_basic_tcp.sh b/tests/e2e/test_basic_tcp.sh index 1ed5b54..085f854 100755 --- a/tests/e2e/test_basic_tcp.sh +++ b/tests/e2e/test_basic_tcp.sh @@ -18,9 +18,7 @@ trap cleanup EXIT # Start the tunnel server exec ./target/debug/castled & server_pid=$! - -# Give the server some time to start -sleep 1 +sleep 0.2 # Start the tunnel client exec ./target/debug/castle tcp 12348 --remote-port 9992 & @@ -28,8 +26,7 @@ client_pid=$! # Start the nc TCP server exec nc -l -p 12348 > actual.txt & # it closes with nc's timeout - -sleep 1 +sleep 0.2 # quit after 1 second # TODO(sword): we don't require this timeout, diff --git a/tests/e2e/test_client_close_when_server_close.sh b/tests/e2e/test_client_close_when_server_close.sh new file mode 100755 index 0000000..a0cb348 --- /dev/null +++ b/tests/e2e/test_client_close_when_server_close.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -x + +# Start the tunnel server +exec ./target/debug/castled & +server_pid=$! +sleep 1 + +exec ./target/debug/castle tcp 12348 --remote-port 9992 & +client_pid=$! + +sleep 1 + +kill -SIGINT $server_pid + +# Wait for the client process to exit +wait $client_pid + +# Check the exit code should be 1 +if [ $? -ne 1 ]; then + echo "Test failed" + exit 1 +fi diff --git a/tests/lib.rs b/tests/lib.rs index f4f8ed9..1d44846 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -36,7 +36,7 @@ async fn client_register_tcp() { ..Default::default() }) .await; - let close_client = server.cancel.clone().wait_shutdown_triggered(); + let close_client = server.cancel.clone(); let remote_port = free_port().unwrap(); let control_addr = server.control_addr().clone(); @@ -71,7 +71,7 @@ async fn client_register_tcp() { server.vhttp_port, ); - server.cancel.trigger_shutdown(()).unwrap(); + server.cancel.trigger_shutdown(0).unwrap(); let client_exit = tokio::join!(client_handler); assert!(client_exit.0.is_ok()); @@ -96,14 +96,14 @@ async fn client_register_and_close_then_register_again() { SocketAddr::from(([127, 0, 0, 1], 8971)), remote_port, ), - close_client.wait_shutdown_triggered(), + close_client.clone(), ) .await; }); tokio::spawn(async move { sleep(tokio::time::Duration::from_millis(300)).await; - shutdown.trigger_shutdown(()); + shutdown.trigger_shutdown(0); }); let client_exit = tokio::join!(client_handler); @@ -122,14 +122,14 @@ async fn client_register_and_close_then_register_again() { SocketAddr::from(([127, 0, 0, 1], 8971)), /* no matter */ remote_port, ), - close_client.wait_shutdown_triggered(), + close_client, ) .await; }); tokio::spawn(async move { sleep(tokio::time::Duration::from_millis(300)).await; - shutdown.trigger_shutdown(()); + shutdown.trigger_shutdown(0); }); let client_exit = tokio::join!(client_handler); @@ -170,7 +170,7 @@ async fn register_http_tunnel_with_subdomain() { false, 0, ), - close_client.wait_shutdown_triggered(), + close_client, ) .await; wait_client_register.send(()).unwrap(); @@ -195,7 +195,7 @@ async fn register_http_tunnel_with_subdomain() { ); assert_eq!(response.text().await.unwrap(), mock_body); - server.cancel.trigger_shutdown(()).unwrap(); + server.cancel.trigger_shutdown(0).unwrap(); let client_exit = tokio::join!(client_handler); assert!(client_exit.0.is_ok()); @@ -443,9 +443,7 @@ async fn test_assigned_entrypoint() { let client_handler = tokio::spawn(async move { let client = Client::new(control_addr); - let entrypoint = client - .start_tunnel(tunnel.tunnel, close_client.wait_shutdown_triggered()) - .await; + let entrypoint = client.start_tunnel(tunnel.tunnel, close_client).await; assert!(entrypoint.is_ok()); let entrypoint = entrypoint.unwrap(); assert_eq!(entrypoint, tunnel.expected); @@ -453,7 +451,7 @@ async fn test_assigned_entrypoint() { tokio::spawn(async move { sleep(tokio::time::Duration::from_millis(100)).await; - shutdown.trigger_shutdown(()).unwrap(); + shutdown.trigger_shutdown(0).unwrap(); }); let _ = tokio::join!(client_handler); @@ -464,7 +462,7 @@ async fn test_assigned_entrypoint() { struct TestServer { control_port: u16, vhttp_port: u16, - cancel: ShutdownManager<()>, + cancel: ShutdownManager, } impl TestServer {