Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(websocket): Add proxy configuration #1536

Open
wants to merge 1 commit into
base: v1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions plugins/websocket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ rand = "0.8"
futures-util = "0.3"
tokio = { version = "1", features = ["net", "sync"] }
tokio-tungstenite = { version = "0.23", features = ["native-tls"] }
hyper = { version = "1.4.1", features = ["client"] }
hyper-util = { version = "0.1.6", features = ["tokio", "http1"] }
base64 = "0.22.1"
189 changes: 178 additions & 11 deletions plugins/websocket/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use base64::prelude::{Engine, BASE64_STANDARD};
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use http::header::{HeaderName, HeaderValue};
use http::{
header::{HeaderName, HeaderValue},
Request,
};
use hyper::client::conn;
use hyper_util::rt::TokioIo;
use serde::{ser::Serializer, Deserialize, Serialize};
use tauri::{
api::ipc::{format_callback, CallbackFn},
plugin::{Builder as PluginBuilder, TauriPlugin},
Manager, Runtime, State, Window,
AppHandle, Manager, Runtime, State, Window,
};
use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::{
connect_async_tls_with_config,
client_async_tls_with_config, connect_async_with_config,
tungstenite::{
client::IntoClientRequest,
error::UrlError,
protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
Message,
},
Expand All @@ -22,7 +29,8 @@ use std::str::FromStr;

type Id = u32;
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WebSocketWriter = SplitSink<WebSocket, Message>;
type WebSocketWriter =
SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>, Message>;
type Result<T> = std::result::Result<T, Error>;

#[derive(Debug, thiserror::Error)]
Expand All @@ -35,6 +43,14 @@ enum Error {
InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
#[error(transparent)]
InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
#[error(transparent)]
ProxyConnectionError(#[from] hyper::Error),
#[error("proxy returned status code: {0}")]
ProxyStatusError(u16),
#[error(transparent)]
ProxyIoError(std::io::Error),
#[error(transparent)]
ProxyHttpError(http::Error),
}

impl Serialize for Error {
Expand All @@ -50,6 +66,26 @@ impl Serialize for Error {
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);

struct TlsConnector(Mutex<Option<Connector>>);
struct ProxyConfigurationInternal(Mutex<Option<ProxyConfiguration>>);

#[derive(Clone)]
pub struct ProxyAuth {
pub username: String,
pub password: String,
}

impl ProxyAuth {
pub fn encode(&self) -> String {
BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password))
}
}

#[derive(Clone)]
pub struct ProxyConfiguration {
pub proxy_url: String,
pub proxy_port: u16,
pub auth: Option<ProxyAuth>,
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -105,10 +141,6 @@ async fn connect<R: Runtime>(
) -> Result<Id> {
let id = rand::random();
let mut request = url.into_client_request()?;
let tls_connector = match window.try_state::<TlsConnector>() {
Some(tls_connector) => tls_connector.0.lock().await.clone(),
None => None,
};

if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
for (k, v) in headers {
Expand All @@ -118,9 +150,32 @@ async fn connect<R: Runtime>(
}
}

let (ws_stream, _) =
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
.await?;
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let tls_connector = match window.try_state::<TlsConnector>() {
Some(tls_connector) => tls_connector.0.lock().await.clone(),
None => None,
};
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
let tls_connector = None;

let proxy_config = match window.try_state::<ProxyConfigurationInternal>() {
Some(proxy_config) => proxy_config.0.lock().await.clone(),
None => None,
};

let ws_stream = if let Some(proxy_config) = proxy_config {
connect_using_proxy(request, config, proxy_config, tls_connector).await?
} else {
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let (ws_stream, _) =
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
.await?;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
let (ws_stream, _) =
connect_async_with_config(request, config.map(Into::into), false).await?;

ws_stream
};

tauri::async_runtime::spawn(async move {
let (write, read) = ws_stream.split();
Expand Down Expand Up @@ -168,6 +223,70 @@ async fn connect<R: Runtime>(
Ok(id)
}

async fn connect_using_proxy(
request: Request<()>,
config: Option<ConnectionConfig>,
proxy_config: ProxyConfiguration,
tls_connector: Option<Connector>,
) -> Result<WebSocket> {
let domain = domain(&request)?;
let port = request
.uri()
.port_u16()
.or_else(|| match request.uri().scheme_str() {
Some("wss") => Some(443),
Some("ws") => Some(80),
_ => None,
})
.ok_or(Error::Websocket(
tokio_tungstenite::tungstenite::Error::Url(UrlError::UnsupportedUrlScheme),
))?;

let tcp = TcpStream::connect(format!(
"{}:{}",
proxy_config.proxy_url, proxy_config.proxy_port
))
.await
.map_err(|original| Error::ProxyIoError(original))?;
let io = TokioIo::new(tcp);

let (mut request_sender, proxy_connection) =
conn::http1::handshake::<TokioIo<tokio::net::TcpStream>, String>(io).await?;
let proxy_connection_task = tokio::spawn(proxy_connection.without_shutdown());

let addr = format!("{domain}:{port}");
let mut req_builder = Request::connect(addr);

if let Some(auth) = proxy_config.auth {
req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", auth.encode()));
}

let req = req_builder
.body("".to_string())
.map_err(|orig| Error::ProxyHttpError(orig))?;
let res = request_sender.send_request(req).await?;
if res.status().as_u16() < 200 || res.status().as_u16() >= 300 {
return Err(Error::ProxyStatusError(res.status().as_u16()));
}

// expect is fine since it would only rely panics from within the tokio task (or a cancellation which does not happen)
let proxy_connection = proxy_connection_task
.await
.expect("Panic in tokio task during websocket proxy initialization")?;

let proxy_tcp_wrapper = proxy_connection.io;
let proxied_tcp_socket = proxy_tcp_wrapper.into_inner();
let (ws_stream, _) = client_async_tls_with_config(
request,
proxied_tcp_socket,
config.map(Into::into),
tls_connector,
)
.await?;

Ok(ws_stream)
}

#[tauri::command]
async fn send(
manager: State<'_, ConnectionManager>,
Expand Down Expand Up @@ -200,12 +319,14 @@ pub fn init<R: Runtime>() -> TauriPlugin<R> {
#[derive(Default)]
pub struct Builder {
tls_connector: Option<Connector>,
proxy_configuration: Option<ProxyConfiguration>,
}

impl Builder {
pub fn new() -> Self {
Self {
tls_connector: None,
proxy_configuration: None,
}
}

Expand All @@ -214,14 +335,60 @@ impl Builder {
self
}

pub fn proxy_configuration(mut self, proxy_configuration: ProxyConfiguration) -> Self {
self.proxy_configuration.replace(proxy_configuration);
self
}

pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
PluginBuilder::new("websocket")
.invoke_handler(tauri::generate_handler![connect, send])
.setup(|app| {
app.manage(ConnectionManager::default());
app.manage(TlsConnector(Mutex::new(self.tls_connector)));
app.manage(ProxyConfigurationInternal(Mutex::new(
self.proxy_configuration,
)));

Ok(())
})
.build()
}
}

pub async fn reconfigure_proxy(app: &AppHandle, proxy_config: Option<ProxyConfiguration>) {
if let Some(state) = app.try_state::<ProxyConfigurationInternal>() {
if let Some(proxy_config) = proxy_config {
state.0.lock().await.replace(proxy_config);
} else {
state.0.lock().await.take();
}
}
}

pub async fn reconfigure_tls_connector(app: &AppHandle, tls_connector: Option<Connector>) {
if let Some(state) = app.try_state::<TlsConnector>() {
if let Some(tls_connector) = tls_connector {
state.0.lock().await.replace(tls_connector);
} else {
state.0.lock().await.take();
}
}
}

// Copied from tokio-tungstenite internal function (tokio-tungstenite/src/lib.rs) with the same name
// Get a domain from an URL.
#[inline]
fn domain(
request: &tokio_tungstenite::tungstenite::handshake::client::Request,
) -> tokio_tungstenite::tungstenite::Result<String, tokio_tungstenite::tungstenite::Error> {
match request.uri().host() {
// rustls expects IPv6 addresses without the surrounding [] brackets
#[cfg(feature = "__rustls-tls")]
Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
Some(d) => Ok(d.to_string()),
None => Err(tokio_tungstenite::tungstenite::Error::Url(
tokio_tungstenite::tungstenite::error::UrlError::NoHostName,
)),
}
}
Loading