From 2960a42893fb0a41731abf506acf6088359ddb05 Mon Sep 17 00:00:00 2001 From: gngpp Date: Wed, 7 Feb 2024 12:26:45 +0800 Subject: [PATCH] feat: trait-variant solves async trait dynamic scheduling --- crates/openai/Cargo.toml | 1 + crates/openai/src/auth/mod.rs | 48 +++++++++---------- crates/openai/src/auth/provide/apple.rs | 2 +- crates/openai/src/auth/provide/mod.rs | 10 ++-- crates/openai/src/auth/provide/platform.rs | 2 +- crates/openai/src/auth/provide/web.rs | 2 +- crates/openai/src/serve/middleware/limit.rs | 4 +- .../src/serve/middleware/tokenbucket.rs | 6 +-- crates/openai/src/serve/mod.rs | 4 +- 9 files changed, 42 insertions(+), 37 deletions(-) diff --git a/crates/openai/Cargo.toml b/crates/openai/Cargo.toml index b153ef1c3..a0589708b 100644 --- a/crates/openai/Cargo.toml +++ b/crates/openai/Cargo.toml @@ -69,6 +69,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"], optional = async-stream = { version = "0.3.5", optional = true } axum_csrf = { version = "0.8.0", features = ["layer"], optional = true } serde_urlencoded = { version = "0.7.1", optional = true } +trait-variant = "0.1.1" [target.'cfg(target_family = "unix")'.dependencies] nix = { version = "0.27.1", default-features = false, features = ["user"] } diff --git a/crates/openai/src/auth/mod.rs b/crates/openai/src/auth/mod.rs index 5d7ff6459..554dab077 100644 --- a/crates/openai/src/auth/mod.rs +++ b/crates/openai/src/auth/mod.rs @@ -49,7 +49,7 @@ static EMAIL_REGEX: OnceCell = OnceCell::const_new(); #[derive(Clone)] pub struct AuthClient { inner: Client, - providers: Vec, + providers: Vec, } impl AuthClient { @@ -267,8 +267,8 @@ impl AuthClient { } impl AuthProvider for AuthClient { - fn supports(&self, t: &AuthStrategy) -> bool { - self.providers.iter().any(|strategy| strategy.supports(t)) + fn support(&self, t: &AuthStrategy) -> bool { + self.providers.iter().any(|strategy| strategy.support(t)) } async fn do_access_token( @@ -288,7 +288,7 @@ impl AuthProvider for AuthClient { // Try supported providers for provider in self.providers.iter() { - if provider.supports(&account.option) { + if provider.support(&account.option) { return provider.do_access_token(account).await; } } @@ -299,7 +299,7 @@ impl AuthProvider for AuthClient { async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()> { let mut result: Option> = None; for handle in self.providers.iter() { - if handle.supports(&AuthStrategy::Apple) || handle.supports(&AuthStrategy::Platform) { + if handle.support(&AuthStrategy::Apple) || handle.support(&AuthStrategy::Platform) { let res = handle.do_revoke_token(refresh_token).await; match res { Ok(ok) => { @@ -320,7 +320,7 @@ impl AuthProvider for AuthClient { let mut result: Option> = None; for handle in self.providers.iter() { - if handle.supports(&AuthStrategy::Apple) || handle.supports(&AuthStrategy::Platform) { + if handle.support(&AuthStrategy::Apple) || handle.support(&AuthStrategy::Platform) { let res = handle.do_refresh_token(refresh_token).await; match res { Ok(ok) => { @@ -482,17 +482,17 @@ impl AuthClientBuilder { let mut providers = Vec::with_capacity(3); // Web Login privider - providers.push(AuthProviderContext::Web(WebAuthProvider(client.clone()))); + providers.push(Prividers::Web(WebAuthProvider(client.clone()))); // Apple Login privider #[cfg(feature = "preauth")] - providers.push(AuthProviderContext::Apple(AppleAuthProvider { + providers.push(Prividers::Apple(AppleAuthProvider { inner: client.clone(), preauth_provider: PreAuthProvider, })); // Platform Login privider - providers.push(AuthProviderContext::Platform(PlatformAuthProvider( + providers.push(Prividers::Platform(PlatformAuthProvider( client.clone(), ))); @@ -508,20 +508,20 @@ impl AuthClientBuilder { } #[derive(Clone)] -pub(crate) enum AuthProviderContext { +pub(crate) enum Prividers { Web(WebAuthProvider), #[cfg(feature = "preauth")] Apple(AppleAuthProvider), Platform(PlatformAuthProvider), } -impl AuthProvider for AuthProviderContext { - fn supports(&self, t: &AuthStrategy) -> bool { +impl AuthProvider for Prividers { + fn support(&self, t: &AuthStrategy) -> bool { match self { - AuthProviderContext::Web(provider) => provider.supports(t), + Prividers::Web(provider) => provider.support(t), #[cfg(feature = "preauth")] - AuthProviderContext::Apple(provider) => provider.supports(t), - AuthProviderContext::Platform(provider) => provider.supports(t), + Prividers::Apple(provider) => provider.support(t), + Prividers::Platform(provider) => provider.support(t), } } @@ -530,19 +530,19 @@ impl AuthProvider for AuthProviderContext { account: &model::AuthAccount, ) -> AuthResult { match self { - AuthProviderContext::Web(provider) => provider.do_access_token(account).await, + Prividers::Web(provider) => provider.do_access_token(account).await, #[cfg(feature = "preauth")] - AuthProviderContext::Apple(provider) => provider.do_access_token(account).await, - AuthProviderContext::Platform(provider) => provider.do_access_token(account).await, + Prividers::Apple(provider) => provider.do_access_token(account).await, + Prividers::Platform(provider) => provider.do_access_token(account).await, } } async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()> { match self { - AuthProviderContext::Web(provider) => provider.do_revoke_token(refresh_token).await, + Prividers::Web(provider) => provider.do_revoke_token(refresh_token).await, #[cfg(feature = "preauth")] - AuthProviderContext::Apple(provider) => provider.do_revoke_token(refresh_token).await, - AuthProviderContext::Platform(provider) => { + Prividers::Apple(provider) => provider.do_revoke_token(refresh_token).await, + Prividers::Platform(provider) => { provider.do_revoke_token(refresh_token).await } } @@ -550,10 +550,10 @@ impl AuthProvider for AuthProviderContext { async fn do_refresh_token(&self, refresh_token: &str) -> AuthResult { match self { - AuthProviderContext::Web(provider) => provider.do_refresh_token(refresh_token).await, + Prividers::Web(provider) => provider.do_refresh_token(refresh_token).await, #[cfg(feature = "preauth")] - AuthProviderContext::Apple(provider) => provider.do_refresh_token(refresh_token).await, - AuthProviderContext::Platform(provider) => { + Prividers::Apple(provider) => provider.do_refresh_token(refresh_token).await, + Prividers::Platform(provider) => { provider.do_refresh_token(refresh_token).await } } diff --git a/crates/openai/src/auth/provide/apple.rs b/crates/openai/src/auth/provide/apple.rs index c4a49f687..c9dea2f3f 100644 --- a/crates/openai/src/auth/provide/apple.rs +++ b/crates/openai/src/auth/provide/apple.rs @@ -294,7 +294,7 @@ impl AppleAuthProvider { } impl AuthProvider for AppleAuthProvider { - fn supports(&self, t: &AuthStrategy) -> bool { + fn support(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Apple) } diff --git a/crates/openai/src/auth/provide/mod.rs b/crates/openai/src/auth/provide/mod.rs index 9565bff54..bb20529a7 100644 --- a/crates/openai/src/auth/provide/mod.rs +++ b/crates/openai/src/auth/provide/mod.rs @@ -20,16 +20,20 @@ use typed_builder::TypedBuilder; pub type AuthResult = anyhow::Result; -#[allow(async_fn_in_trait)] -pub trait AuthProvider { +#[trait_variant::make(AuthProvider: Send)] +pub trait LocalAuthProvider { + /// Do the access token authentication process. async fn do_access_token(&self, account: &model::AuthAccount) -> AuthResult; + /// Do the refresh token authentication process. async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()>; + /// Do the refresh token authentication process. async fn do_refresh_token(&self, refresh_token: &str) -> AuthResult; - fn supports(&self, t: &AuthStrategy) -> bool; + /// Check if the provider supports the given auth strategy. + fn support(&self, t: &AuthStrategy) -> bool; } trait RequestContextExt { diff --git a/crates/openai/src/auth/provide/platform.rs b/crates/openai/src/auth/provide/platform.rs index 2e5b6a51b..2a2a89521 100644 --- a/crates/openai/src/auth/provide/platform.rs +++ b/crates/openai/src/auth/provide/platform.rs @@ -259,7 +259,7 @@ impl PlatformAuthProvider { } impl AuthProvider for PlatformAuthProvider { - fn supports(&self, t: &AuthStrategy) -> bool { + fn support(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Platform) } diff --git a/crates/openai/src/auth/provide/web.rs b/crates/openai/src/auth/provide/web.rs index 3335fa937..88cccfa32 100644 --- a/crates/openai/src/auth/provide/web.rs +++ b/crates/openai/src/auth/provide/web.rs @@ -288,7 +288,7 @@ impl WebAuthProvider { } impl AuthProvider for WebAuthProvider { - fn supports(&self, t: &AuthStrategy) -> bool { + fn support(&self, t: &AuthStrategy) -> bool { t.eq(&AuthStrategy::Web) } diff --git a/crates/openai/src/serve/middleware/limit.rs b/crates/openai/src/serve/middleware/limit.rs index 96ca26f55..e8418b1f9 100644 --- a/crates/openai/src/serve/middleware/limit.rs +++ b/crates/openai/src/serve/middleware/limit.rs @@ -6,10 +6,10 @@ use axum::{ response::Response, }; -use super::tokenbucket::{TokenBucket, TokenBucketLimitContext}; +use super::tokenbucket::{TokenBucket, TokenBucketProvider}; pub(crate) async fn limit_middleware( - State(limit): State>, + State(limit): State>, ConnectInfo(socket_addr): ConnectInfo, request: Request, next: Next, diff --git a/crates/openai/src/serve/middleware/tokenbucket.rs b/crates/openai/src/serve/middleware/tokenbucket.rs index 77c587bcb..45595717d 100644 --- a/crates/openai/src/serve/middleware/tokenbucket.rs +++ b/crates/openai/src/serve/middleware/tokenbucket.rs @@ -275,12 +275,12 @@ fn ip_to_number(ip: IpAddr) -> u128 { } } -pub enum TokenBucketLimitContext { +pub enum TokenBucketProvider { Mem(MemTokenBucket), ReDB(RedisTokenBucket<'static>), } -impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketLimitContext { +impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketProvider { fn from(value: (Strategy, bool, u32, u32, u32)) -> Self { let strategy = match value.0 { Strategy::Mem => Self::Mem(MemTokenBucket::new(value.1, value.2, value.3, value.4)), @@ -290,7 +290,7 @@ impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketLimitContext { } } -impl TokenBucket for TokenBucketLimitContext { +impl TokenBucket for TokenBucketProvider { async fn acquire(&self, ip: IpAddr) -> anyhow::Result { let condition = match self { Self::Mem(t) => t.acquire(ip).await, diff --git a/crates/openai/src/serve/mod.rs b/crates/openai/src/serve/mod.rs index 04c0f85f5..f54706dd2 100644 --- a/crates/openai/src/serve/mod.rs +++ b/crates/openai/src/serve/mod.rs @@ -25,7 +25,7 @@ use crate::dns; use crate::proxy::{InnerProxy, Proxy}; use crate::serve::error::ProxyError; use crate::serve::error::ResponseError; -use crate::serve::middleware::tokenbucket::{Strategy, TokenBucketLimitContext}; +use crate::serve::middleware::tokenbucket::{Strategy, TokenBucketProvider}; use crate::{info, warn, with_context}; use crate::{URL_CHATGPT_API, URL_PLATFORM_API}; use axum::body::Body; @@ -158,7 +158,7 @@ impl Serve { // init auth layer provider let app_layer = { - let limit_context = TokenBucketLimitContext::from(( + let limit_context = TokenBucketProvider::from(( Strategy::from_str(self.0.tb_strategy.as_str())?, self.0.tb_enable, self.0.tb_capacity,