From 8ff964ef8e94591f4bc4f733050e785fb45b0927 Mon Sep 17 00:00:00 2001 From: Sword Date: Tue, 23 Jul 2024 13:32:37 +0900 Subject: [PATCH] support uploading large file through multipart (#81) * add upload file e2e test * Support upload files(multipart) through http tunnel --- README.md | 1 + src/client/client.rs | 21 +- src/constant.rs | 2 + src/lib.rs | 1 + src/server/control_server.rs | 26 +- src/server/tunnel/http.rs | 255 +++++++++++++----- tests/e2e/file_server/main.go | 65 +++++ .../test_upload_file_through_http_tunnel.sh | 42 +++ tests/e2e/util.sh | 1 + 9 files changed, 335 insertions(+), 79 deletions(-) create mode 100644 src/constant.rs create mode 100644 tests/e2e/file_server/main.go create mode 100755 tests/e2e/test_upload_file_through_http_tunnel.sh diff --git a/README.md b/README.md index c5494fd..78ec5fb 100644 --- a/README.md +++ b/README.md @@ -38,4 +38,5 @@ Basically, this tunnel is primarily for this purpose. If you want to expose your - random subdomain if `--random-subdomain` is specified - random remote port if not specified - support http/1.1 + - Upload file - [ ] support http/2 diff --git a/src/client/client.rs b/src/client/client.rs index fe58670..c1cde91 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -6,13 +6,14 @@ use tonic::{transport::Channel, Response, Status, Streaming}; use tracing::{debug, error, info, instrument, span}; use tokio::{ - io::{self, AsyncRead, AsyncWrite, AsyncWriteExt}, + io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, net::{TcpStream, UdpSocket}, select, sync::{mpsc, oneshot}, }; use crate::{ + constant, io::{StreamingReader, StreamingWriter, TrafficToServerWrapper}, pb::{ self, control::Payload, traffic_to_server, tunnel::Type, @@ -469,14 +470,19 @@ async fn handle_work_traffic( /// 1. remote <=> me /// 2. me <=> local async fn forward_traffic_to_local( - mut local_r: impl AsyncRead + Unpin, + local_r: impl AsyncRead + Unpin, mut local_w: impl AsyncWrite + Unpin, - mut remote_r: StreamingReader, + remote_r: StreamingReader, mut remote_w: StreamingWriter, ) -> Result<()> { let remote_to_me_to_local = async { // read from remote, write to local - match io::copy(&mut remote_r, &mut local_w).await { + match io::copy_buf( + &mut BufReader::with_capacity(constant::DEFAULT_BUF_SIZE, remote_r), + &mut local_w, + ) + .await + { Ok(n) => { debug!("copied {} bytes from remote to local", n); let _ = local_w.shutdown().await; @@ -489,7 +495,12 @@ async fn forward_traffic_to_local( let local_to_me_to_remote = async { // read from local, write to remote - match io::copy(&mut local_r, &mut remote_w).await { + match io::copy_buf( + &mut BufReader::with_capacity(constant::DEFAULT_BUF_SIZE, local_r), + &mut remote_w, + ) + .await + { Ok(n) => { debug!("copied {} bytes from local to remote", n); let _ = remote_w.shutdown().await; diff --git a/src/constant.rs b/src/constant.rs new file mode 100644 index 0000000..7d2a4f9 --- /dev/null +++ b/src/constant.rs @@ -0,0 +1,2 @@ +// use 8K as the default buffer size when transferring io data. +pub(crate) const DEFAULT_BUF_SIZE: usize = 8 * 1024; diff --git a/src/lib.rs b/src/lib.rs index 498675d..2bae815 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub(crate) mod bridge; +pub(crate) mod constant; pub(crate) mod event; pub(crate) mod helper; pub(crate) mod io; diff --git a/src/server/control_server.rs b/src/server/control_server.rs index e6d017e..85b7108 100644 --- a/src/server/control_server.rs +++ b/src/server/control_server.rs @@ -1,7 +1,7 @@ use crate::event::ClientEventResponse; use crate::helper::validate_register_req; use crate::pb::{traffic_to_server, TrafficToServer}; -use crate::{bridge, event}; +use crate::{bridge, constant, event}; use crate::{ io::CancellableReceiver, pb::{ @@ -394,11 +394,25 @@ impl TunnelService for ControlHandler { tokio::select! { // server -> client Some(data) = transfer_rx.recv() => { - outbound_tx - .send(Ok(TrafficToClient { data })) - .await - .context("failed to send traffic to outbound channel") - .unwrap(); + if data.len() <= constant::DEFAULT_BUF_SIZE { + outbound_tx + .send(Ok(TrafficToClient { data })) + .await + .context("failed to send traffic to outbound channel") + .unwrap(); + } else { + // read data chunk by chunk, then send to the client + // if data is empty, the first chunk will be none. + // which means we have no change to send the data. + // so if the length of data is less than 8192, we can send it directly. + for data in data.chunks(constant::DEFAULT_BUF_SIZE) { + outbound_tx + .send(Ok(TrafficToClient { data: data.to_vec() })) + .await + .context("failed to send traffic to outbound channel") + .unwrap(); + } + } } _ = close_sender_listener.cancelled() => { // after connection is removed, this listener will be notified diff --git a/src/server/tunnel/http.rs b/src/server/tunnel/http.rs index 83660b9..9832396 100644 --- a/src/server/tunnel/http.rs +++ b/src/server/tunnel/http.rs @@ -7,7 +7,8 @@ use anyhow::{Context as _, Result}; use bytes::{BufMut as _, Bytes}; use dashmap::DashMap; use futures::TryStreamExt; -use http::HeaderValue; +use http::response::Builder; +use http::{HeaderValue, StatusCode}; use http_body::Frame; use http_body_util::combinators::BoxBody; use http_body_util::{BodyDataStream, BodyExt, Full, StreamBody}; @@ -15,6 +16,7 @@ use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{body::Incoming, Request, Response}; use hyper_util::rt::TokioIo; +use std::cell::Cell; use std::convert::Infallible; use std::io::Write; use std::sync::Arc; @@ -188,7 +190,7 @@ impl Http { // response to user http request by sending data to outbound_rx let (body_tx, body_rx) = mpsc::channel::, Infallible>>(1024); - let (header_tx, header_rx) = oneshot::channel::>(); + let (header_tx, header_rx) = oneshot::channel::>(); let client_cancel_receiver = bridge.client_cancel_receiver.clone(); @@ -209,50 +211,145 @@ impl Http { .unwrap() } header = header_rx => { - let header_buf = header.unwrap(); - let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; - let mut resp = httparse::Response::new(&mut headers); - if let Err(err) = resp.parse(&header_buf) { - error!(err = ?err, "failed to parse response header"); - return Response::builder() - .status(500) - .body(BoxBody::new(Full::new(Bytes::from_static(b"failed to parse response header")))) - .unwrap(); - } - if resp.code.is_none() { - return Response::builder() - .status(500) - .body(BoxBody::new(Full::new(Bytes::from_static(b"invalid response header: no status")))) - .unwrap(); + // get response builder from header + let http_builder = header.unwrap(); + match http_builder { + Ok(http_builder) => { + let stream = ReceiverStream::new(body_rx); + let body = BoxBody::new(StreamBody::new(stream)); + http_builder.body(body).unwrap() + }, + Err(err) => { + error!(err = ?err, "failed to get response builder"); + Response::builder() + .status(500) + .body(BoxBody::new(Full::new(Bytes::from_static(b"failed to get response builder")))) + .unwrap() + }, } - let mut http_builder = Response::builder().status(resp.code.unwrap()) - .version(match resp.version { - Some(0) => http::Version::HTTP_10, - _ => http::Version::HTTP_11, - }); - for header in resp.headers { - http_builder = http_builder.header(header.name, header.value); - } + // let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + // let mut resp = httparse::Response::new(&mut headers); + // if let Err(err) = resp.parse(&header_buf) { + // error!(err = ?err, "failed to parse response header"); + // return Response::builder() + // .status(500) + // .body(BoxBody::new(Full::new(Bytes::from_static(b"failed to parse response header")))) + // .unwrap(); + // } + // if resp.code.is_none() { + // error!("invalid response header: no status"); + // return Response::builder() + // .status(500) + // .body(BoxBody::new(Full::new(Bytes::from_static(b"invalid response header: no status")))) + // .unwrap(); + // } + + // let mut http_builder = Response::builder().status(resp.code.unwrap()) + // .version(match resp.version { + // Some(0) => http::Version::HTTP_10, + // _ => http::Version::HTTP_11, + // }); + // for header in resp.headers { + // http_builder = http_builder.header(header.name, header.value); + // } + } + } + } +} + +struct ResponseHeaderScanner { + buf_cell: Cell>, + ended: bool, + pos: usize, +} + +impl ResponseHeaderScanner { + fn new(buf_size: usize) -> Self { + Self { + buf_cell: Cell::new(Vec::with_capacity(buf_size)), + ended: false, + pos: 0, + } + } - let stream = ReceiverStream::new(body_rx); - let body = BoxBody::new(StreamBody::new(stream)); - http_builder.body(body).unwrap() + fn split_parts(&mut self) -> (&[u8], &[u8]) { + let buffer = self.buf_cell.get_mut(); + buffer.split_at(self.pos) + } + + /// Helper function to find the end of the header (\r\n\r\n) in the buffer. + /// Returns true if the end of the header if it found. + fn scan(&mut self, new_buf: Vec) -> bool { + let buffer = self.buf_cell.get_mut(); + buffer.extend_from_slice(&new_buf); + for i in self.pos..buffer.len() - 3 { + if buffer[i] == b'\r' + && buffer[i + 1] == b'\n' + && buffer[i + 2] == b'\r' + && buffer[i + 3] == b'\n' + { + self.pos = i + 4; + self.ended = true; + return true; } } + false + } + + fn parse(&mut self) -> Result)>> { + let (header_part, body_part) = self.split_parts(); + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + let mut resp = httparse::Response::new(&mut headers); + resp.parse(header_part)?; + if resp.code.is_none() { + return Err(anyhow::anyhow!("invalid response header: no status")); + } + + if resp.code.unwrap() == StatusCode::CONTINUE { + // discard 100-continue, because the hyper server has handled it. + // continue to look for the next header end. + let mut new_header = Vec::with_capacity(MAX_HEADER_SIZE); + new_header.extend_from_slice(body_part); + self.reset(new_header); + return Ok(None); + } + + let mut http_builder = + Response::builder() + .status(resp.code.unwrap()) + .version(match resp.version { + Some(0) => http::Version::HTTP_10, + _ => http::Version::HTTP_11, + }); + for header in resp.headers { + http_builder = http_builder.header(header.name, header.value); + } + Ok(Some((http_builder, body_part.to_vec()))) + } + + fn move_back(&mut self, n: usize) { + self.pos = self.pos.saturating_sub(n); + } + + fn reset(&mut self, buffer: Vec) { + self.ended = false; + self.pos = 0; + self.buf_cell.set(buffer); } } async fn receive_response( mut data_receiver: mpsc::Receiver, - mut header_tx: Option>>, + mut header_tx: Option>>, body_tx: mpsc::Sender, Infallible>>, client_cancel_receiver: CancellationToken, remove_bridge_sender: CancellationToken, ) { - let mut header = Vec::with_capacity(MAX_HEADER_SIZE); - let mut header_ended = false; - let mut scan_buf_start = 0; + // let mut header_cell = Cell::new(Vec::with_capacity(MAX_HEADER_SIZE)); + // let mut header_ended = false; + // let mut scan_buf_start = 0; + let mut header_scanner = ResponseHeaderScanner::new(MAX_HEADER_SIZE); loop { tokio::select! { @@ -263,7 +360,7 @@ async fn receive_response( match data { BridgeData::Data(data) => { if data.is_empty() { - if !header_ended { + if !header_scanner.ended { // this spawn will drop the body_tx and header_tx error!("unexpected empty data, header not ended"); } @@ -271,20 +368,24 @@ async fn receive_response( break; } - if !header_ended { - header.extend_from_slice(&data); - if let Some(header_end_pos) = find_header_end(&header, scan_buf_start) { - header_ended = true; - let (header_part, body_part) = header.split_at(header_end_pos + 1); - if let Some(header_tx) = header_tx.take() { - header_tx.send(header_part.to_vec()).unwrap(); + if !header_scanner.ended { + if header_scanner.scan(data) { + match header_scanner.parse() { + Err(err) => { + header_tx.take().unwrap().send(Err(err)).unwrap(); + break; + } + Ok(None) => { + continue + } + Ok(Some((http_builder, body_part))) => { + header_tx.take().unwrap().send(Ok(http_builder)).unwrap(); + let frame = Frame::data(Bytes::from(body_part.to_vec())); + let _ = body_tx.send(Ok(frame)).await; + } } - - let frame = Frame::data(Bytes::from(body_part.to_vec())); - let _ = body_tx.send(Ok(frame)).await; - } else { - scan_buf_start = header.len().saturating_sub(3); // min: 0 + header_scanner.move_back(4); // min: 0 } } else { let frame = Frame::data(Bytes::from(data)); @@ -415,21 +516,6 @@ impl DynamicRegistry { } } -/// Helper function to find the end of the header (\r\n\r\n) in the buffer. -/// Returns the position of the end of the header if found, otherwise None. -fn find_header_end(buffer: &[u8], start: usize) -> Option { - for i in start..buffer.len() - 3 { - if buffer[i] == b'\r' - && buffer[i + 1] == b'\n' - && buffer[i + 2] == b'\r' - && buffer[i + 3] == b'\n' - { - return Some(i + 3); - } - } - None -} - #[cfg(test)] mod test { use std::pin::Pin; @@ -496,7 +582,7 @@ mod test { + Sync + 'a, >, - expected_header: &'a [u8], + expected_header: &'a str, expected_body: &'a [u8], } @@ -507,12 +593,12 @@ mod test { Box::pin(async move { let _ = tx .send(BridgeData::Data( - b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello".to_vec(), + b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nhello".to_vec(), )) .await; }) }), - expected_header: b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n", + expected_header: "HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\n", expected_body: b"hello", }, Case { @@ -527,7 +613,7 @@ mod test { .await; }) }), - expected_header: b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n", + expected_header: "HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\n", expected_body: b"hello", }, Case { @@ -536,13 +622,30 @@ mod test { Box::pin(async move { let _ = tx .send(BridgeData::Data( - b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n".to_vec(), + b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n".to_vec(), )) .await; let _ = tx.send(BridgeData::Data(b"\r\nhello".to_vec())).await; }) }), - expected_header: b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n", + expected_header: "HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\n", + expected_body: b"hello", + }, + Case { + name: "response 100 then response 200", + send_fn: Box::new(|tx| { + Box::pin(async move { + let _ = tx + .send(BridgeData::Data(b"HTTP/1.1 100 Continue\r\n\r\n".to_vec())) + .await; + let _ = tx + .send(BridgeData::Data( + b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nhello".to_vec(), + )) + .await; + }) + }), + expected_header: "HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\n", expected_body: b"hello", }, ]; @@ -567,8 +670,8 @@ mod test { )); (case.send_fn)(data_tx).await; - let header = header_rx.await.unwrap(); - assert_eq!(header, case.expected_header); + let http_builder = header_rx.await.unwrap().unwrap(); + assert_eq!(http_builder_to_string(http_builder), case.expected_header); let body = body_rx.recv().await.unwrap().unwrap(); assert_eq!( body.into_data().unwrap(), @@ -579,4 +682,20 @@ mod test { remove_bridge_receiver.cancelled().await; } } + + fn http_builder_to_string<'a>(builder: Builder) -> &'a str { + let (parts, _) = builder.body(()).unwrap().into_parts(); + let mut buf = String::new(); + buf.push_str(&format!( + "{:?} {:?} {}\r\n", + parts.version, + parts.status, + parts.status.canonical_reason().unwrap(), + )); + for (key, value) in parts.headers.iter() { + buf.push_str(&format!("{}: {}\r\n", key, value.to_str().unwrap())); + } + buf.push_str("\r\n"); + Box::leak(buf.into_boxed_str()) + } } diff --git a/tests/e2e/file_server/main.go b/tests/e2e/file_server/main.go new file mode 100644 index 0000000..6415c20 --- /dev/null +++ b/tests/e2e/file_server/main.go @@ -0,0 +1,65 @@ +package main + +import ( + "fmt" + "io" + "log" + "net/http" + "os" + + "github.com/spf13/pflag" +) + +var ( + port int + host string +) + +func main() { + pflag.IntVar(&port, "port", 8080, "") + pflag.StringVar(&host, "host", "127.0.0.1", "") + pflag.Parse() + + mux := http.NewServeMux() + + mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { + // accept multipart form like + // file1=@/path/to/large_file1.txt;filename=/tmp/dest_large_file1.txt + // file2=@/path/to/large_file2.txt;filename=/tmp/dest_large_file2.txt + // save the file to the filename + if err := r.ParseMultipartForm(200 * (2 << 20)); err != nil { // 200MB + http.Error(w, "Unable to parse multipart form", http.StatusInternalServerError) + return + } + for _, headers := range r.MultipartForm.File { + for _, header := range headers { + file, err := header.Open() + if err != nil { + http.Error(w, "Unable to open file", http.StatusInternalServerError) + return + } + defer file.Close() + + // Create destination file + dst, err := os.Create("/tmp/" + header.Filename) + if err != nil { + http.Error(w, "Unable to create destination file", http.StatusInternalServerError) + return + } + + n, err := io.Copy(dst, file) + if err != nil { + http.Error(w, "Unable to copy file content", http.StatusInternalServerError) + return + } + dst.Close() + log.Printf("uploaded %s with size %d", header.Filename, n) + } + } + fmt.Fprint(w, "OK") + }) + + server := &http.Server{Addr: fmt.Sprintf("%s:%d", host, port), Handler: mux} + println(server.Addr) + panic(server.ListenAndServe()) +} diff --git a/tests/e2e/test_upload_file_through_http_tunnel.sh b/tests/e2e/test_upload_file_through_http_tunnel.sh new file mode 100755 index 0000000..0332f47 --- /dev/null +++ b/tests/e2e/test_upload_file_through_http_tunnel.sh @@ -0,0 +1,42 @@ +#!/bin/bash +set -x + +root_dir=$(git rev-parse --show-toplevel) +cur_dir=$root_dir/tests/e2e +source $cur_dir/util.sh + +rm -f /tmp/test_file*.txt + +cleanup() { + echo "Cleaning up..." + kill -SIGINT $server_pid + kill -SIGINT $client_pid + kill $file_server_pid +} + +trap cleanup EXIT + +exec $root_dir/target/debug/castled & +server_pid=$! +wait_port 6610 + +RUST_LOG=DEBUG exec $root_dir/target/debug/castle http 13346 --remote-port 6890 & +client_pid=$! +wait_port 6890 + +RUST_LOG=DEBUG exec $cur_dir/.bin/file_server --port 13346 & +file_server_pid=$! + +dd if=/dev/zero of=/tmp/test_file1.txt bs=1K count=2 #(2K) +dd if=/dev/zero of=/tmp/test_file3.txt bs=1M count=100 #(100M) + +response_code=$(curl -s -o /dev/null -w "%{http_code}" -X POST -F "file=@/tmp/test_file1.txt;filename=/tmp/test_file2.txt" -F "file=@/tmp/test_file3.txt;filename=/tmp/test_file4.txt" http://localhost:6890/upload?dest=/tmp/test_file2.txt) +if [ $response_code -eq 200 ]; then + echo "Response code is 200" +else + echo "Response code is $response_code" + exit 1 +fi + +diff /tmp/test_file1.txt /tmp/test_file2.txt +diff /tmp/test_file3.txt /tmp/test_file4.txt diff --git a/tests/e2e/util.sh b/tests/e2e/util.sh index 52489a5..764972e 100644 --- a/tests/e2e/util.sh +++ b/tests/e2e/util.sh @@ -3,6 +3,7 @@ cd "$(dirname "$0")" go build -o .bin/ ./ping/ping.go +go build -o .bin/file_server ./file_server/main.go wait_port() { local port=$1