diff --git a/build.sh b/build.sh index c2334b040..4af5aaedd 100755 --- a/build.sh +++ b/build.sh @@ -22,29 +22,35 @@ root=$(pwd) [ ! -d uploads ] && mkdir uploads # Separate arrays for target architectures and Docker images -target_architectures=("x86_64-unknown-linux-musl" "aarch64-unknown-linux-musl" "armv7-unknown-linux-musleabi" "armv7-unknown-linux-musleabihf" "arm-unknown-linux-musleabi" "arm-unknown-linux-musleabihf" "armv5te-unknown-linux-musleabi" "i686-unknown-linux-gnu" "i586-unknown-linux-gnu" "x86_64-pc-windows-msvc") -docker_images=("ghcr.io/gngpp/rust-musl-cross:x86_64-musl" "ghcr.io/gngpp/rust-musl-cross:aarch64-musl" "ghcr.io/gngpp/rust-musl-cross:armv7-musleabi" "ghcr.io/gngpp/rust-musl-cross:armv7-musleabihf" "ghcr.io/gngpp/rust-musl-cross:arm-musleabi" "ghcr.io/gngpp/rust-musl-cross:arm-musleabihf" "ghcr.io/gngpp/rust-musl-cross:armv5te-musleabi" "ghcr.io/gngpp/rust-musl-cross:i686-musl" "ghcr.io/gngpp/rust-musl-cross:i586-musl" "ghcr.io/gngpp/cargo-xwin:latest") - -get_docker_image() { - local target_arch="$1" - local index - for ((index = 0; index < ${#target_architectures[@]}; ++index)); do - if [ "${target_architectures[index]}" == "$target_arch" ]; then - echo "${docker_images[index]}" - return 0 - fi - done - - echo "Architecture not found" - return 1 +target_architectures=( + "x86_64-unknown-linux-musl" + "aarch64-unknown-linux-musl" + "armv7-unknown-linux-musleabi" + "armv7-unknown-linux-musleabihf" + "arm-unknown-linux-musleabi" + "arm-unknown-linux-musleabihf" + "armv5te-unknown-linux-musleabi" + "i686-unknown-linux-gnu" + "i586-unknown-linux-gnu" + "x86_64-pc-windows-msvc" +) + +pull_docker_image() { + image="ghcr.io/gngpp/rust-musl-cross:$1" + echo "Pulling $image" + docker pull $image } rmi_docker_image() { - echo "Removing $1" - docker rmi $1 + image="ghcr.io/gngpp/rust-musl-cross:$1" + echo "Removing $image docker image" + if [ "$rmi" = "true" ]; then + docker rmi $image + fi } build_macos_target() { + echo "Building $1" cargo build --release --target $1 --features mimalloc sudo chmod -R 777 target cd target/$1/release @@ -56,6 +62,8 @@ build_macos_target() { } build_linux_target() { + docker_image="ghcr.io/gngpp/rust-musl-cross:$1" + features="" if [ "$1" = "armv5te-unknown-linux-musleabi" ] || [ "$1" = "arm-unknown-linux-musleabi" ] || [ "$1" = "arm-unknown-linux-musleabihf" ]; then features="--features rpmalloc" @@ -67,8 +75,7 @@ build_linux_target() { fi fi - docker_image=$(get_docker_image "$1") - + echo "Building $1" docker run --rm -t --user=$UID:$(id -g $USER) \ -v $(pwd):/home/rust/src \ -v $HOME/.cargo/registry:/root/.cargo/registry \ @@ -90,8 +97,9 @@ build_linux_target() { } build_windows_target() { - docker_image=$(get_docker_image "$1") + docker_image="ghcr.io/gngpp/rust-musl-cross:$1" + echo "Building $1" docker run --rm -t \ -v $(pwd):/home/rust/src \ -v $HOME/.cargo/registry:/usr/local/cargo/registry \ @@ -109,48 +117,40 @@ build_windows_target() { } if [ "$os" = "windows" ]; then - target_list=(x86_64-pc-windows-msvc) - for target in "${target_list[@]}"; do - echo "Building $target" - - docker_image=$(get_docker_image "$target") - - build_windows_target "$target" - - if [ "$rmi" = "true" ]; then - rmi_docker_image "$docker_image" - fi - done + target="x86_64-pc-windows-msvc" + pull_docker_image "$target" + build_windows_target "$target" + rmi_docker_image "$target" fi if [ "$os" = "linux" ]; then - target_list=(x86_64-unknown-linux-musl aarch64-unknown-linux-musl armv7-unknown-linux-musleabi armv7-unknown-linux-musleabihf armv5te-unknown-linux-musleabi arm-unknown-linux-musleabi arm-unknown-linux-musleabihf i686-unknown-linux-gnu i586-unknown-linux-gnu) + target_list=( + "x86_64-unknown-linux-musl" + "aarch64-unknown-linux-musl" + "armv7-unknown-linux-musleabi" + "armv7-unknown-linux-musleabihf" + "armv5te-unknown-linux-musleabi" + "arm-unknown-linux-musleabi" + "arm-unknown-linux-musleabihf" + "i686-unknown-linux-gnu" + "i586-unknown-linux-gnu" + ) for target in "${target_list[@]}"; do - echo "Building $target" - - docker_image=$(get_docker_image "$target") - - if [ "$target" = "x86_64-pc-windows-msvc" ]; then - build_windows_target "$target" - else - build_linux_target "$target" - fi - - if [ "$rmi" = "true" ]; then - rmi_docker_image "$docker_image" - fi + pull_docker_image "$target" + build_linux_target "$target" + rmi_docker_image "$target" done fi if [ "$os" = "macos" ]; then - if ! which upx &>/dev/null; then - brew install upx - fi - rustup target add x86_64-apple-darwin aarch64-apple-darwin - target_list=(x86_64-apple-darwin aarch64-apple-darwin) + target_list=( + "x86_64-apple-darwin" + "aarch64-apple-darwin" + ) for target in "${target_list[@]}"; do - echo "Building $target" + echo "Adding $target to the build queue" + rustup target add "$target" build_macos_target "$target" done fi diff --git a/crates/mitm/Cargo.toml b/crates/mitm/Cargo.toml index 7797f3b3a..791a4b6ac 100644 --- a/crates/mitm/Cargo.toml +++ b/crates/mitm/Cargo.toml @@ -9,7 +9,6 @@ edition = "2021" log = "0.4.20" anyhow = "1.0.75" thiserror = "1.0.48" -async-trait = "0.1.73" reqwest = { package = "reqwest-impersonate", version ="0.11.49", default-features = false, features = [ "boring-tls", "impersonate", "stream", "socks" ] } diff --git a/crates/mitm/src/proxy/handler.rs b/crates/mitm/src/proxy/handler.rs index 7057094f4..29d9a32e5 100644 --- a/crates/mitm/src/proxy/handler.rs +++ b/crates/mitm/src/proxy/handler.rs @@ -1,17 +1,15 @@ -use async_trait::async_trait; use hyper::{Body, Request, Response}; use std::sync::{Arc, RwLock}; use wildmatch::WildMatch; use super::mitm::RequestOrResponse; -#[async_trait] pub trait HttpHandler: Clone + Send + Sync + 'static { - async fn handle_request(&self, req: Request) -> RequestOrResponse { + fn handle_request(&self, req: Request) -> RequestOrResponse { RequestOrResponse::Request(req) } - async fn handle_response(&self, res: Response) -> Response { + fn handle_response(&self, res: Response) -> Response { res } } diff --git a/crates/mitm/src/proxy/mitm.rs b/crates/mitm/src/proxy/mitm.rs index b036dd05f..25ea4a54b 100644 --- a/crates/mitm/src/proxy/mitm.rs +++ b/crates/mitm/src/proxy/mitm.rs @@ -93,7 +93,7 @@ where }; // Proxy request - let mut req = match self.http_handler.handle_request(req).await { + let mut req = match self.http_handler.handle_request(req) { RequestOrResponse::Request(req) => req, RequestOrResponse::Response(res) => return Ok(res), }; @@ -116,7 +116,7 @@ where } }; - let mut res = self.http_handler.handle_response(res).await; + let mut res = self.http_handler.handle_response(res); let length = res.size_hint().lower(); { diff --git a/crates/openai/Cargo.toml b/crates/openai/Cargo.toml index 81939a0f6..b153ef1c3 100644 --- a/crates/openai/Cargo.toml +++ b/crates/openai/Cargo.toml @@ -2,6 +2,7 @@ name = "openai" version = "0.9.26" edition = "2021" +rust-version = "1.75.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -19,11 +20,9 @@ tokio = { version = "1.35.1", features = ["fs", "sync", "signal", "rt-multi-thre serde_json = "1.0.107" serde = {version = "1.0.188", features = ["derive"] } regex = "1.9.5" -async-recursion = "1.0.5" url = { version = "2.5.0", features = ["serde"] } base64 = "0.21.4" rand = "0.8.5" -async-trait = "0.1.73" typed-builder = "0.18.0" jsonwebtokens = "1.2.0" sha2 = "0.10.7" diff --git a/crates/openai/src/auth/mod.rs b/crates/openai/src/auth/mod.rs index 68865667f..5d7ff6459 100644 --- a/crates/openai/src/auth/mod.rs +++ b/crates/openai/src/auth/mod.rs @@ -18,7 +18,7 @@ use serde::de::DeserializeOwned; use base64::{engine::general_purpose, Engine as _}; use rand::Rng; -use reqwest::{Client, Proxy, StatusCode, Url}; +use reqwest::{Client, ClientBuilder, Proxy, StatusCode, Url}; use sha2::{Digest, Sha256}; use tokio::sync::OnceCell; @@ -30,6 +30,7 @@ use error::AuthError; use self::model::{ApiKeyData, AuthStrategy}; #[cfg(feature = "preauth")] use self::provide::apple::AppleAuthProvider; +use self::provide::apple::PreAuthProvider; use self::provide::platform::PlatformAuthProvider; use self::provide::web::WebAuthProvider; use self::provide::{AuthProvider, AuthResult}; @@ -48,7 +49,7 @@ static EMAIL_REGEX: OnceCell = OnceCell::const_new(); #[derive(Clone)] pub struct AuthClient { inner: Client, - providers: Arc>>, + providers: Vec, } impl AuthClient { @@ -265,7 +266,6 @@ impl AuthClient { } } -#[async_trait::async_trait] impl AuthProvider for AuthClient { fn supports(&self, t: &AuthStrategy) -> bool { self.providers.iter().any(|strategy| strategy.supports(t)) @@ -338,19 +338,17 @@ impl AuthProvider for AuthClient { } } -pub struct AuthClientBuilder { - inner: reqwest::ClientBuilder, -} +pub struct AuthClientBuilder(ClientBuilder); impl AuthClientBuilder { // Proxy options pub fn proxy(mut self, proxy: Option) -> Self { if let Some(url) = proxy { - self.inner = self - .inner + self.0 = self + .0 .proxy(Proxy::all(url).expect("reqwest: invalid proxy url")); } else { - self.inner = self.inner.no_proxy(); + self.0 = self.0.no_proxy(); } self } @@ -364,7 +362,7 @@ impl AuthClientBuilder { /// /// Default is no timeout. pub fn timeout(mut self, timeout: Duration) -> Self { - self.inner = self.inner.timeout(timeout); + self.0 = self.0.timeout(timeout); self } @@ -377,7 +375,7 @@ impl AuthClientBuilder { /// This **requires** the futures be executed in a tokio runtime with /// a tokio timer enabled. pub fn connect_timeout(mut self, timeout: Duration) -> Self { - self.inner = self.inner.connect_timeout(timeout); + self.0 = self.0.connect_timeout(timeout); self } @@ -392,13 +390,13 @@ impl AuthClientBuilder { where D: Into>, { - self.inner = self.inner.pool_idle_timeout(val); + self.0 = self.0.pool_idle_timeout(val); self } /// Sets the maximum idle connection per host allowed in the pool. pub fn pool_max_idle_per_host(mut self, max: usize) -> Self { - self.inner = self.inner.pool_max_idle_per_host(max); + self.0 = self.0.pool_max_idle_per_host(max); self } @@ -409,19 +407,19 @@ impl AuthClientBuilder { where D: Into>, { - self.inner = self.inner.tcp_keepalive(val); + self.0 = self.0.tcp_keepalive(val); self } /// Sets the necessary values to mimic the specified impersonate client version. pub fn impersonate(mut self, ver: Impersonate) -> Self { - self.inner = self.inner.impersonate(ver); + self.0 = self.0.impersonate(ver); self } /// Sets the `User-Agent` header to be used by this client. pub fn user_agent(mut self, value: &str) -> Self { - self.inner = self.inner.user_agent(value); + self.0 = self.0.user_agent(value); self } @@ -430,14 +428,14 @@ impl AuthClientBuilder { where T: Into>, { - self.inner = self.inner.local_address(addr); + self.0 = self.0.local_address(addr); self } /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's /// preferences) before connection. pub fn local_addresses(mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) -> Self { - self.inner = self.inner.local_addresses(addr_ipv4, addr_ipv6); + self.0 = self.0.local_addresses(addr_ipv4, addr_ipv6); self } @@ -447,31 +445,31 @@ impl AuthClientBuilder { /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will /// still be applied on top of this resolver. pub fn dns_resolver(mut self, resolver: Arc) -> Self { - self.inner = self.inner.dns_resolver(resolver); + self.0 = self.0.dns_resolver(resolver); self } /// Controls the use of certificate validation. pub fn danger_accept_invalid_certs(mut self, enable: bool) -> Self { - self.inner = self.inner.danger_accept_invalid_certs(enable); + self.0 = self.0.danger_accept_invalid_certs(enable); self } /// Enable Encrypted Client Hello (Secure SNI) pub fn enable_ech_grease(mut self, enable: bool) -> Self { - self.inner = self.inner.enable_ech_grease(enable); + self.0 = self.0.enable_ech_grease(enable); self } /// Enable TLS permute_extensions pub fn permute_extensions(mut self, enable: bool) -> Self { - self.inner = self.inner.permute_extensions(enable); + self.0 = self.0.permute_extensions(enable); self } pub fn build(self) -> AuthClient { let client = self - .inner + .0 .default_headers({ let mut headers = HeaderMap::new(); headers.insert(header::ORIGIN, HeaderValue::from_static(OPENAI_OAUTH_URL)); @@ -481,21 +479,83 @@ impl AuthClientBuilder { .build() .expect("ClientBuilder::build()"); - let mut providers: Vec> = Vec::with_capacity(3); - providers.push(Box::new(WebAuthProvider::new(client.clone()))); - providers.push(Box::new(PlatformAuthProvider::new(client.clone()))); + let mut providers = Vec::with_capacity(3); + + // Web Login privider + providers.push(AuthProviderContext::Web(WebAuthProvider(client.clone()))); + + // Apple Login privider #[cfg(feature = "preauth")] - providers.push(Box::new(AppleAuthProvider::new(client.clone()))); + providers.push(AuthProviderContext::Apple(AppleAuthProvider { + inner: client.clone(), + preauth_provider: PreAuthProvider, + })); + + // Platform Login privider + providers.push(AuthProviderContext::Platform(PlatformAuthProvider( + client.clone(), + ))); AuthClient { inner: client, - providers: Arc::new(providers), + providers, } } pub fn builder() -> AuthClientBuilder { - AuthClientBuilder { - inner: Client::builder().redirect(Policy::none()), + AuthClientBuilder(Client::builder().redirect(Policy::none())) + } +} + +#[derive(Clone)] +pub(crate) enum AuthProviderContext { + Web(WebAuthProvider), + #[cfg(feature = "preauth")] + Apple(AppleAuthProvider), + Platform(PlatformAuthProvider), +} + +impl AuthProvider for AuthProviderContext { + fn supports(&self, t: &AuthStrategy) -> bool { + match self { + AuthProviderContext::Web(provider) => provider.supports(t), + #[cfg(feature = "preauth")] + AuthProviderContext::Apple(provider) => provider.supports(t), + AuthProviderContext::Platform(provider) => provider.supports(t), + } + } + + async fn do_access_token( + &self, + account: &model::AuthAccount, + ) -> AuthResult { + match self { + AuthProviderContext::Web(provider) => provider.do_access_token(account).await, + #[cfg(feature = "preauth")] + AuthProviderContext::Apple(provider) => provider.do_access_token(account).await, + AuthProviderContext::Platform(provider) => provider.do_access_token(account).await, + } + } + + async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()> { + match self { + AuthProviderContext::Web(provider) => provider.do_revoke_token(refresh_token).await, + #[cfg(feature = "preauth")] + AuthProviderContext::Apple(provider) => provider.do_revoke_token(refresh_token).await, + AuthProviderContext::Platform(provider) => { + provider.do_revoke_token(refresh_token).await + } + } + } + + async fn do_refresh_token(&self, refresh_token: &str) -> AuthResult { + match self { + AuthProviderContext::Web(provider) => provider.do_refresh_token(refresh_token).await, + #[cfg(feature = "preauth")] + AuthProviderContext::Apple(provider) => provider.do_refresh_token(refresh_token).await, + AuthProviderContext::Platform(provider) => { + provider.do_refresh_token(refresh_token).await + } } } } diff --git a/crates/openai/src/auth/provide/apple.rs b/crates/openai/src/auth/provide/apple.rs index 72671f57d..c4a49f687 100644 --- a/crates/openai/src/auth/provide/apple.rs +++ b/crates/openai/src/auth/provide/apple.rs @@ -6,7 +6,6 @@ use crate::auth::{ OPENAI_OAUTH_REVOKE_URL, OPENAI_OAUTH_TOKEN_URL, OPENAI_OAUTH_URL, }; use crate::{warn, with_context}; -use async_recursion::async_recursion; use axum::http::HeaderValue; use reqwest::Client; use url::Url; @@ -23,6 +22,7 @@ const APPLE_CLIENT_ID: &str = "pdlLIX2Y72MIl2rhLhTE9VV9bN905kBh"; const OPENAI_OAUTH_APPLE_CALLBACK_URL: &str = "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback"; +#[derive(Clone)] pub(crate) struct PreAuthProvider; impl PreAuthProvider { @@ -31,19 +31,13 @@ impl PreAuthProvider { } } +#[derive(Clone)] pub(crate) struct AppleAuthProvider { - inner: Client, - preauth_provider: PreAuthProvider, + pub inner: Client, + pub preauth_provider: PreAuthProvider, } impl AppleAuthProvider { - pub fn new(inner: Client) -> impl AuthProvider + Send + Sync { - Self { - inner, - preauth_provider: PreAuthProvider, - } - } - async fn authorize(&self, ctx: &mut RequestContext<'_>) -> AuthResult<()> { // Get the preauth cookie. let preauth_cookie = self.preauth_provider.get_preauth_cookie()?; @@ -199,7 +193,6 @@ impl AppleAuthProvider { Err(AuthError::FailedCallbackURL) } - #[async_recursion] async fn authenticate_mfa( &self, ctx: &mut RequestContext<'_>, @@ -236,7 +229,33 @@ impl AppleAuthProvider { return Err(AuthError::MFAFailed); } - self.authenticate_resume(ctx, location).await + let resp = self + .inner + .get(&format!("{OPENAI_OAUTH_URL}{location}")) + .ext_context(ctx) + .send() + .await + .map_err(AuthError::FailedRequest)? + .ext_context(ctx); + + // If resp status is client error return InvalidEmailOrPassword + if resp.status().is_client_error() { + return Err(AuthError::InvalidEmailOrPassword); + } + + // maybe auth failed + let _ = AuthClient::check_auth_callback_state(resp.url())?; + + // If get_location_path returns an error, it means that the location is invalid. + let location: &str = AuthClient::get_location_path(&resp.headers())?; + + // Indicates successful login. + if location.starts_with(OPENAI_OAUTH_APPLE_CALLBACK_URL) { + return self.authorization_code(ctx, location).await; + } + + // Return an error if the location is invalid. + Err(AuthError::FailedCallbackURL) } async fn authorization_code( @@ -274,7 +293,6 @@ impl AppleAuthProvider { } } -#[async_trait::async_trait] impl AuthProvider for AppleAuthProvider { fn supports(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Apple) diff --git a/crates/openai/src/auth/provide/mod.rs b/crates/openai/src/auth/provide/mod.rs index 580d28bf0..9565bff54 100644 --- a/crates/openai/src/auth/provide/mod.rs +++ b/crates/openai/src/auth/provide/mod.rs @@ -20,8 +20,8 @@ use typed_builder::TypedBuilder; pub type AuthResult = anyhow::Result; -#[async_trait::async_trait] -pub trait AuthProvider: Send + Sync { +#[allow(async_fn_in_trait)] +pub trait AuthProvider { async fn do_access_token(&self, account: &model::AuthAccount) -> AuthResult; diff --git a/crates/openai/src/auth/provide/platform.rs b/crates/openai/src/auth/provide/platform.rs index 1c35d7256..2e5b6a51b 100644 --- a/crates/openai/src/auth/provide/platform.rs +++ b/crates/openai/src/auth/provide/platform.rs @@ -6,7 +6,6 @@ use crate::auth::{ OPENAI_OAUTH_REVOKE_URL, OPENAI_OAUTH_TOKEN_URL, OPENAI_OAUTH_URL, }; use crate::warn; -use async_recursion::async_recursion; use axum::http::HeaderValue; use reqwest::Client; use url::Url; @@ -19,13 +18,10 @@ use super::{ const PLATFORM_CLIENT_ID: &str = "DRivsnm2Mu42T3KOpqdtwB3NYviHYzwD"; const OPENAI_OAUTH_PLATFORM_CALLBACK_URL: &str = "https://platform.openai.com/auth/callback"; -pub(crate) struct PlatformAuthProvider(pub(crate) Client); +#[derive(Clone)] +pub(crate) struct PlatformAuthProvider(pub Client); impl PlatformAuthProvider { - pub fn new(inner: Client) -> impl AuthProvider + Send + Sync { - Self(inner) - } - async fn authorize(&self, ctx: &mut RequestContext<'_>) -> AuthResult<()> { // Build url let url = format!("{OPENAI_OAUTH_URL}/authorize?client_id={PLATFORM_CLIENT_ID}&scope=openid%20email%20profile%20offline_access%20model.request%20model.read%20organization.read%20organization.write&audience=https://api.openai.com/v1&redirect_uri=https://platform.openai.com/auth/callback&response_type=code"); @@ -165,7 +161,6 @@ impl PlatformAuthProvider { Err(AuthError::FailedCallbackURL) } - #[async_recursion] async fn authenticate_mfa( &self, ctx: &mut RequestContext<'_>, @@ -211,7 +206,28 @@ impl PlatformAuthProvider { if location.starts_with("/authorize/resume?") && ctx.account.mfa.is_none() { return Err(AuthError::MFAFailed); } - self.authenticate_resume(ctx, location).await + + let resp = self + .0 + .get(&format!("{OPENAI_OAUTH_URL}{location}")) + .ext_context(ctx) + .send() + .await + .map_err(AuthError::FailedRequest)? + .ext_context(ctx); + + // maybe auth failed + let _ = AuthClient::check_auth_callback_state(resp.url())?; + + // Get location path + let location: &str = AuthClient::get_location_path(&resp.headers())?; + + // If location path starts with https://platform.openai.com/auth/callback + if location.starts_with(OPENAI_OAUTH_PLATFORM_CALLBACK_URL) { + return self.authorization_code(location).await; + } + + Err(AuthError::FailedCallbackURL) } async fn authorization_code(&self, location: &str) -> AuthResult { @@ -242,7 +258,6 @@ impl PlatformAuthProvider { } } -#[async_trait::async_trait] impl AuthProvider for PlatformAuthProvider { fn supports(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Platform) diff --git a/crates/openai/src/auth/provide/web.rs b/crates/openai/src/auth/provide/web.rs index f9abc9c61..3335fa937 100644 --- a/crates/openai/src/auth/provide/web.rs +++ b/crates/openai/src/auth/provide/web.rs @@ -13,13 +13,10 @@ use reqwest::{Client, StatusCode}; use serde_json::Value; use url::Url; -pub(crate) struct WebAuthProvider(pub(crate) Client); +#[derive(Clone)] +pub(crate) struct WebAuthProvider(pub Client); impl WebAuthProvider { - pub fn new(inner: Client) -> impl AuthProvider + Send + Sync { - Self(inner) - } - async fn csrf_token(&self, ctx: &mut RequestContext<'_>) -> AuthResult<()> { let resp = self .0 @@ -290,7 +287,6 @@ impl WebAuthProvider { } } -#[async_trait::async_trait] impl AuthProvider for WebAuthProvider { fn supports(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Web) diff --git a/crates/openai/src/lib.rs b/crates/openai/src/lib.rs index 7dd09c037..25d6111bc 100644 --- a/crates/openai/src/lib.rs +++ b/crates/openai/src/lib.rs @@ -1,4 +1,3 @@ -#![recursion_limit = "256"] pub mod arkose; pub mod auth; pub mod chatgpt; diff --git a/crates/openai/src/serve/middleware/tokenbucket.rs b/crates/openai/src/serve/middleware/tokenbucket.rs index e3a069b0e..77c587bcb 100644 --- a/crates/openai/src/serve/middleware/tokenbucket.rs +++ b/crates/openai/src/serve/middleware/tokenbucket.rs @@ -7,7 +7,6 @@ use std::time::Duration; use crate::homedir::home_dir; use crate::{context, debug, error, now_duration}; -#[async_trait::async_trait] pub trait TokenBucket: Send + Sync { async fn acquire(&self, ip: IpAddr) -> anyhow::Result; } @@ -68,7 +67,6 @@ impl MemTokenBucket { } } -#[async_trait::async_trait] impl TokenBucket for MemTokenBucket { async fn acquire(&self, ip: IpAddr) -> anyhow::Result { if !self.enable { @@ -212,8 +210,7 @@ fn clear_expired_buckets_every(db: Arc>, expired: u32) { }); } -#[async_trait::async_trait] -impl<'a> TokenBucket for RedisTokenBucket<'_> { +impl TokenBucket for RedisTokenBucket<'_> { async fn acquire(&self, ip: IpAddr) -> anyhow::Result { if !self.enable { return Ok(true); @@ -278,25 +275,27 @@ fn ip_to_number(ip: IpAddr) -> u128 { } } -pub struct TokenBucketLimitContext(Box); +pub enum TokenBucketLimitContext { + Mem(MemTokenBucket), + ReDB(RedisTokenBucket<'static>), +} impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketLimitContext { fn from(value: (Strategy, bool, u32, u32, u32)) -> Self { let strategy = match value.0 { - Strategy::Mem => Self(Box::new(MemTokenBucket::new( - value.1, value.2, value.3, value.4, - ))), - Strategy::ReDB => Self(Box::new(RedisTokenBucket::new( - value.1, value.2, value.3, value.4, - ))), + Strategy::Mem => Self::Mem(MemTokenBucket::new(value.1, value.2, value.3, value.4)), + Strategy::ReDB => Self::ReDB(RedisTokenBucket::new(value.1, value.2, value.3, value.4)), }; strategy } } -#[async_trait::async_trait] impl TokenBucket for TokenBucketLimitContext { async fn acquire(&self, ip: IpAddr) -> anyhow::Result { - Ok(self.0.acquire(ip).await?) + let condition = match self { + Self::Mem(t) => t.acquire(ip).await, + Self::ReDB(t) => t.acquire(ip).await, + }; + Ok(condition?) } } diff --git a/crates/openai/src/serve/preauth.rs b/crates/openai/src/serve/preauth.rs index 32a2ac44b..a62d7f130 100644 --- a/crates/openai/src/serve/preauth.rs +++ b/crates/openai/src/serve/preauth.rs @@ -10,16 +10,15 @@ use std::fmt::Write; #[derive(Clone)] pub struct PreAuthHanlder; -#[async_trait::async_trait] impl HttpHandler for PreAuthHanlder { - async fn handle_request(&self, req: Request) -> RequestOrResponse { - log_req(&req).await; + fn handle_request(&self, req: Request) -> RequestOrResponse { + log_req(&req); collect_preauth_cookie(req.headers()); RequestOrResponse::Request(req) } - async fn handle_response(&self, res: Response) -> Response { - log_res(&res).await; + fn handle_response(&self, res: Response) -> Response { + log_res(&res); collect_preauth_cookie(res.headers()); res } @@ -36,7 +35,7 @@ fn collect_preauth_cookie(headers: &HeaderMap) { } } -pub async fn log_req(req: &Request) { +pub fn log_req(req: &Request) { let headers = req.headers(); let mut header_formated = String::new(); for (key, value) in headers { @@ -65,7 +64,7 @@ Headers: ) } -pub async fn log_res(res: &Response) { +pub fn log_res(res: &Response) { let headers = res.headers(); let mut header_formated = String::new(); for (key, value) in headers { diff --git a/patches b/patches index b86fb30bd..0137ea3c2 160000 --- a/patches +++ b/patches @@ -1 +1 @@ -Subproject commit b86fb30bde5e0619a1e37cbec2562479fedf8c01 +Subproject commit 0137ea3c2252a17dc07da97337518909d8e1ac3b diff --git a/src/daemon.rs b/src/daemon.rs index f3381aa55..557b57b0f 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -285,6 +285,7 @@ pub(super) fn generate_template(out: Option) -> anyhow::Result<()> { tb_expired: 86400, cookie_store: true, pool_idle_timeout: 90, + arkose_solver_limit: 3, level: "info".to_owned(), pcert: PathBuf::from("ca/cert.crt"), pkey: PathBuf::from("ca/key.pem"),