Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(preauth): Reconstruct PreAuth service-related MITM modules #320

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,4 @@ uploads
./vendor
upgrade
ca
chat3.openai.com.har
chat4.openai.com.har
platform.openai.com.har
login.chat.openai.com.har
4.har
auth0.openai.com_Archive.har
auth0.openai.com.har
*.har
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ anyhow = "1.0.75"
clap = { version = "4.4.3", features = ["derive", "env"] }
serde = {version = "1.0.188", features = ["derive"] }
openai = { path = "openai" }
mitm ={ path = "mitm", optional = true }
cidr = "0.2.2"
toml = "0.8.0"
url = "2.4.1"
Expand Down Expand Up @@ -57,7 +58,8 @@ env_logger = "0.10.0"
openai = { path = "openai" }

[features]
default = ["serve"]
default = ["serve", "mitm"]
mitm = ["openai/preauth", "dep:mitm"]
terminal = [
"openai/api",
"dep:tokio",
Expand Down
29 changes: 29 additions & 0 deletions mitm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[package]
name = "mitm"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
log = "0.4.20"
anyhow = "1.0.75"
thiserror = "1.0.48"
async-trait = "0.1.73"
reqwest = { package = "reqwest-impersonate", version ="0.11.30", default-features = false, features = [
"boring-tls", "impersonate", "stream", "socks"
] }
typed-builder = "0.18.0"
time = "0.3.30"
rand = "0.8.5"
moka = { version = "0.12.1", default-features = false, features = ["sync"] }
tokio = { version = "1.15.0", default-features = false }
rcgen = { version = "0.10", features = ["x509-parser"] }
hyper = { version = "0.14.27", default-features = false }
tokio-rustls = { version = "0.24.1", default-features = false, features = ["tls12"] }
rustls = { version = "0.21.8", features = ["dangerous_configuration"] }
wildmatch = "2.1"
http = "0.2.11"
pin-project = "1"
byteorder = "1.4"
rustls-pemfile = "1.0"
5 changes: 3 additions & 2 deletions openai/src/serve/preauth/cagen.rs → mitm/src/cagen.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use log::error;
use rcgen::Certificate;

use crate::{error, serve::preauth::proxy::CertificateAuthority};

use std::fs;

use crate::proxy::CertificateAuthority;

pub fn gen_ca() -> Certificate {
let cert = CertificateAuthority::gen_ca().expect("preauth generate cert");
let cert_crt = cert.serialize_pem().unwrap();
Expand Down
60 changes: 60 additions & 0 deletions mitm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
pub mod cagen;
pub mod proxy;

use anyhow::Context;
use std::{fs, net::SocketAddr, path::PathBuf};
use typed_builder::TypedBuilder;

use crate::proxy::{handler::HttpHandler, CertificateAuthority};
use log::info;

#[derive(TypedBuilder)]
pub struct Builder<T: HttpHandler + Clone> {
bind: SocketAddr,
upstream_proxy: Option<String>,
cert: PathBuf,
key: PathBuf,
graceful_shutdown: tokio::sync::mpsc::Receiver<()>,
cerificate_cache_size: u32,
mitm_filters: Vec<String>,
handler: T,
}

impl<T: HttpHandler + Clone> Builder<T> {
pub async fn mitm_proxy(self) -> anyhow::Result<()> {
info!("PreAuth CA Private key use: {}", self.key.display());
let private_key_bytes =
fs::read(self.key).context("ca private key file path not valid!")?;
let private_key = rustls_pemfile::pkcs8_private_keys(&mut private_key_bytes.as_slice())
.context("Failed to parse private key")?;
let key = rustls::PrivateKey(private_key[0].clone());

info!("PreAuth CA Certificate use: {}", self.cert.display());
let ca_cert_bytes = fs::read(self.cert).context("ca cert file path not valid!")?;
let ca_cert = rustls_pemfile::certs(&mut ca_cert_bytes.as_slice())
.context("Failed to parse CA certificate")?;
let cert = rustls::Certificate(ca_cert[0].clone());

let ca = CertificateAuthority::new(
key,
cert,
String::from_utf8(ca_cert_bytes).context("Failed to parse CA certificate")?,
self.cerificate_cache_size.into(),
)
.context("Failed to create Certificate Authority")?;

info!("PreAuth Http MITM Proxy listen on: http://{}", self.bind);

let proxy = proxy::Proxy::builder()
.ca(ca.clone())
.listen_addr(self.bind)
.upstream_proxy(self.upstream_proxy)
.mitm_filters(self.mitm_filters)
.handler(self.handler)
.graceful_shutdown(self.graceful_shutdown)
.build();

tokio::spawn(proxy.start_proxy());
Ok(())
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use http::{response::Builder, Request, Response};
use hyper::Body;
use hyper::{body, Body};
use reqwest::impersonate::Impersonate;

use super::error::Error;
Expand Down Expand Up @@ -33,7 +33,7 @@ impl HttpClient {
.request(method, url)
.headers(parts.headers)
.version(parts.version)
.body(hyper::body::to_bytes(body).await?)
.body(body::to_bytes(body).await?)
.send()
.await?;

Expand All @@ -46,6 +46,6 @@ impl HttpClient {
.headers_mut()
.map(|h| h.extend(resp.headers().clone()));

Ok(builder.body(hyper::body::Body::wrap_stream(resp.bytes_stream()))?)
Ok(builder.body(body::Body::wrap_stream(resp.bytes_stream()))?)
}
}
File renamed without changes.
File renamed without changes.
16 changes: 4 additions & 12 deletions openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ nom = { version = "7.1.3", optional = true }
mime = { version = "0.3.17", optional = true }
futures-timer = { version = "3.0.2", optional = true }

# mitm
mitm = { path = "../mitm", optional = true }

# arkose
aes = "0.8.3"
md5 = "0.7.0"
Expand All @@ -63,17 +66,6 @@ async-stream = { version = "0.3.5", optional = true }
axum_csrf = { version = "0.7.2", features = ["layer"], optional = true }
serde_urlencoded = { version = "0.7.1", optional = true }

# mitm
rcgen = { version = "0.10", features = ["x509-parser"], optional = true }
hyper = { version = "0.14.27", default-features = false, optional = true }
tokio-rustls = { version = "0.24.1", default-features = false, features = ["tls12"], optional = true }
rustls = { version = "0.21.8", features = ["dangerous_configuration"], optional = true }
wildmatch = { version = "2.1", optional = true }
http = { version = "0.2", optional = true }
pin-project = { version = "1", optional = true }
byteorder = { version = "1.4", optional = true }
rustls-pemfile = { version = "1.0", optional = true }

[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.48.0"
default-features = false
Expand All @@ -86,7 +78,7 @@ static-files = "0.2.3"
default = ["serve", "limit", "template", "preauth"]
api = ["stream"]
serve = ["dep:serde_urlencoded", "dep:axum_csrf", "stream", "dep:async-stream", "dep:tracing", "dep:tracing-subscriber", "dep:tower-http", "dep:tower", "dep:bytes", "dep:time", "dep:axum-server", "dep:axum-extra", "dep:axum", "dep:static-files", "dep:futures-core", "dep:tera"]
preauth = ["dep:rustls-pemfile", "dep:rcgen", "dep:moka", "dep:hyper", "dep:tokio-rustls", "dep:rustls", "dep:wildmatch", "dep:http", "dep:pin-project", "dep:byteorder"]
preauth = ["dep:mitm"]
stream = ["dep:tokio-util", "dep:futures", "dep:tokio-stream", "dep:eventsource-stream", "dep:futures-core", "dep:pin-project-lite", "dep:nom", "dep:mime", "dep:futures-timer"]
remote-token = []
limit = ["dep:redis", "dep:redis-macros", "dep:moka"]
Expand Down
3 changes: 2 additions & 1 deletion openai/src/serve/middleware/csrf.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use axum::http::{Method, Request, StatusCode};
use axum::{
body::{self, BoxBody, Full},
middleware::Next,
response::Response,
Form,
};
use axum_csrf::CsrfToken;
use http::{Method, Request, StatusCode};
use mitm::proxy::hyper;

use crate::{auth::model::AuthAccount, warn};

Expand Down
25 changes: 13 additions & 12 deletions openai/src/serve/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod error;
mod middleware;
#[cfg(feature = "preauth")]
pub mod preauth;
mod preauth;
mod proxy;
mod puid;
#[cfg(feature = "template")]
Expand All @@ -14,11 +14,11 @@ use axum::body::Body;
use axum::headers::authorization::Bearer;
use axum::headers::Authorization;
use axum::http::Response;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::{any, get, post};
use axum::{Json, TypedHeader};
use axum_server::{AddrIncomingConfig, Handle};
use http::StatusCode;

use self::proxy::ext::RequestExt;
use self::proxy::ext::SendRequestExt;
Expand Down Expand Up @@ -189,16 +189,17 @@ impl Serve {
// PreAuth mitm proxy
#[cfg(feature = "preauth")]
if let Some(pbind) = self.0.pbind.clone() {
if let Some(err) = preauth::mitm_proxy(
pbind,
self.0.pupstream.clone(),
self.0.pcert.clone(),
self.0.pkey.clone(),
rx,
)
.await
.err()
{
let builder = mitm::Builder::builder()
.bind(pbind)
.upstream_proxy(self.0.pupstream.clone())
.cert(self.0.pcert.clone())
.key(self.0.pkey.clone())
.graceful_shutdown(rx)
.cerificate_cache_size(1_000)
.mitm_filters(vec![String::from("ios.chat.openai.com")])
.handler(preauth::PreAuthHanlder)
.build();
if let Some(err) = builder.mitm_proxy().await.err() {
crate::error!("PreAuth proxy error: {}", err);
}
}
Expand Down
100 changes: 100 additions & 0 deletions openai/src/serve/preauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use mitm::proxy::hyper::{
body::Body,
http::{header, HeaderMap, HeaderValue, Request, Response},
};
use mitm::proxy::{handler::HttpHandler, mitm::RequestOrResponse};
use std::fmt::Write;

use crate::{info, with_context};

#[derive(Clone)]
pub struct PreAuthHanlder;

#[async_trait::async_trait]
impl HttpHandler for PreAuthHanlder {
async fn handle_request(&self, req: Request<Body>) -> RequestOrResponse {
if log::log_enabled!(log::Level::Debug) {
log_req(&req).await;
}
// extract preauth cookie
collect_preauth_cookie(req.headers());
RequestOrResponse::Request(req)
}

async fn handle_response(&self, res: Response<Body>) -> Response<Body> {
if log::log_enabled!(log::Level::Debug) {
log_res(&res).await;
}
collect_preauth_cookie(res.headers());
res
}
}

fn collect_preauth_cookie(headers: &HeaderMap<HeaderValue>) {
headers
.iter()
.filter(|(k, _)| k.eq(&header::COOKIE) || k.eq(&header::SET_COOKIE))
.for_each(|(_, v)| {
let _ = v
.to_str()
.map(|value| with_context!(push_preauth_cookie, value));
});
}

pub async fn log_req(req: &Request<Body>) {
let headers = req.headers();
let mut header_formated = String::new();
for (key, value) in headers {
let v = match value.to_str() {
Ok(v) => v.to_string(),
Err(_) => {
format!("[u8]; {}", value.len())
}
};
write!(
&mut header_formated,
"\t{:<20}{}\r\n",
format!("{}:", key.as_str()),
v
)
.unwrap();
}

info!(
"{} {}
Headers:
{}",
req.method(),
req.uri().to_string(),
header_formated
)
}

pub async fn log_res(res: &Response<Body>) {
let headers = res.headers();
let mut header_formated = String::new();
for (key, value) in headers {
let v = match value.to_str() {
Ok(v) => v.to_string(),
Err(_) => {
format!("[u8]; {}", value.len())
}
};
write!(
&mut header_formated,
"\t{:<20}{}\r\n",
format!("{}:", key.as_str()),
v
)
.unwrap();
}

info!(
"{} {:?}
Headers:
{}",
res.status(),
res.version(),
header_formated
)
}
Loading