Skip to content

Commit

Permalink
feat: trait-variant solves async trait dynamic scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
0x676e67 committed Feb 7, 2024
1 parent 24b0770 commit 2960a42
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 37 deletions.
1 change: 1 addition & 0 deletions crates/openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"], optional =
async-stream = { version = "0.3.5", optional = true }
axum_csrf = { version = "0.8.0", features = ["layer"], optional = true }
serde_urlencoded = { version = "0.7.1", optional = true }
trait-variant = "0.1.1"

[target.'cfg(target_family = "unix")'.dependencies]
nix = { version = "0.27.1", default-features = false, features = ["user"] }
Expand Down
48 changes: 24 additions & 24 deletions crates/openai/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ static EMAIL_REGEX: OnceCell<Regex> = OnceCell::const_new();
#[derive(Clone)]
pub struct AuthClient {
inner: Client,
providers: Vec<AuthProviderContext>,
providers: Vec<Prividers>,
}

impl AuthClient {
Expand Down Expand Up @@ -267,8 +267,8 @@ impl AuthClient {
}

impl AuthProvider for AuthClient {
fn supports(&self, t: &AuthStrategy) -> bool {
self.providers.iter().any(|strategy| strategy.supports(t))
fn support(&self, t: &AuthStrategy) -> bool {
self.providers.iter().any(|strategy| strategy.support(t))
}

async fn do_access_token(
Expand All @@ -288,7 +288,7 @@ impl AuthProvider for AuthClient {

// Try supported providers
for provider in self.providers.iter() {
if provider.supports(&account.option) {
if provider.support(&account.option) {
return provider.do_access_token(account).await;
}
}
Expand All @@ -299,7 +299,7 @@ impl AuthProvider for AuthClient {
async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()> {
let mut result: Option<AuthResult<()>> = None;
for handle in self.providers.iter() {
if handle.supports(&AuthStrategy::Apple) || handle.supports(&AuthStrategy::Platform) {
if handle.support(&AuthStrategy::Apple) || handle.support(&AuthStrategy::Platform) {
let res = handle.do_revoke_token(refresh_token).await;
match res {
Ok(ok) => {
Expand All @@ -320,7 +320,7 @@ impl AuthProvider for AuthClient {
let mut result: Option<AuthResult<model::RefreshToken>> = None;

for handle in self.providers.iter() {
if handle.supports(&AuthStrategy::Apple) || handle.supports(&AuthStrategy::Platform) {
if handle.support(&AuthStrategy::Apple) || handle.support(&AuthStrategy::Platform) {
let res = handle.do_refresh_token(refresh_token).await;
match res {
Ok(ok) => {
Expand Down Expand Up @@ -482,17 +482,17 @@ impl AuthClientBuilder {
let mut providers = Vec::with_capacity(3);

// Web Login privider
providers.push(AuthProviderContext::Web(WebAuthProvider(client.clone())));
providers.push(Prividers::Web(WebAuthProvider(client.clone())));

// Apple Login privider
#[cfg(feature = "preauth")]
providers.push(AuthProviderContext::Apple(AppleAuthProvider {
providers.push(Prividers::Apple(AppleAuthProvider {
inner: client.clone(),
preauth_provider: PreAuthProvider,
}));

// Platform Login privider
providers.push(AuthProviderContext::Platform(PlatformAuthProvider(
providers.push(Prividers::Platform(PlatformAuthProvider(
client.clone(),
)));

Expand All @@ -508,20 +508,20 @@ impl AuthClientBuilder {
}

#[derive(Clone)]
pub(crate) enum AuthProviderContext {
pub(crate) enum Prividers {
Web(WebAuthProvider),
#[cfg(feature = "preauth")]
Apple(AppleAuthProvider),
Platform(PlatformAuthProvider),
}

impl AuthProvider for AuthProviderContext {
fn supports(&self, t: &AuthStrategy) -> bool {
impl AuthProvider for Prividers {
fn support(&self, t: &AuthStrategy) -> bool {
match self {
AuthProviderContext::Web(provider) => provider.supports(t),
Prividers::Web(provider) => provider.support(t),
#[cfg(feature = "preauth")]
AuthProviderContext::Apple(provider) => provider.supports(t),
AuthProviderContext::Platform(provider) => provider.supports(t),
Prividers::Apple(provider) => provider.support(t),
Prividers::Platform(provider) => provider.support(t),
}
}

Expand All @@ -530,30 +530,30 @@ impl AuthProvider for AuthProviderContext {
account: &model::AuthAccount,
) -> AuthResult<model::AccessToken> {
match self {
AuthProviderContext::Web(provider) => provider.do_access_token(account).await,
Prividers::Web(provider) => provider.do_access_token(account).await,
#[cfg(feature = "preauth")]
AuthProviderContext::Apple(provider) => provider.do_access_token(account).await,
AuthProviderContext::Platform(provider) => provider.do_access_token(account).await,
Prividers::Apple(provider) => provider.do_access_token(account).await,
Prividers::Platform(provider) => provider.do_access_token(account).await,
}
}

async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()> {
match self {
AuthProviderContext::Web(provider) => provider.do_revoke_token(refresh_token).await,
Prividers::Web(provider) => provider.do_revoke_token(refresh_token).await,
#[cfg(feature = "preauth")]
AuthProviderContext::Apple(provider) => provider.do_revoke_token(refresh_token).await,
AuthProviderContext::Platform(provider) => {
Prividers::Apple(provider) => provider.do_revoke_token(refresh_token).await,
Prividers::Platform(provider) => {
provider.do_revoke_token(refresh_token).await
}
}
}

async fn do_refresh_token(&self, refresh_token: &str) -> AuthResult<model::RefreshToken> {
match self {
AuthProviderContext::Web(provider) => provider.do_refresh_token(refresh_token).await,
Prividers::Web(provider) => provider.do_refresh_token(refresh_token).await,
#[cfg(feature = "preauth")]
AuthProviderContext::Apple(provider) => provider.do_refresh_token(refresh_token).await,
AuthProviderContext::Platform(provider) => {
Prividers::Apple(provider) => provider.do_refresh_token(refresh_token).await,
Prividers::Platform(provider) => {
provider.do_refresh_token(refresh_token).await
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/openai/src/auth/provide/apple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ impl AppleAuthProvider {
}

impl AuthProvider for AppleAuthProvider {
fn supports(&self, t: &AuthStrategy) -> bool {
fn support(&self, t: &AuthStrategy) -> bool {
t.eq(&AuthStrategy::Apple)
}

Expand Down
10 changes: 7 additions & 3 deletions crates/openai/src/auth/provide/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ use typed_builder::TypedBuilder;

pub type AuthResult<T, E = AuthError> = anyhow::Result<T, E>;

#[allow(async_fn_in_trait)]
pub trait AuthProvider {
#[trait_variant::make(AuthProvider: Send)]
pub trait LocalAuthProvider {
/// Do the access token authentication process.
async fn do_access_token(&self, account: &model::AuthAccount)
-> AuthResult<model::AccessToken>;

/// Do the refresh token authentication process.
async fn do_revoke_token(&self, refresh_token: &str) -> AuthResult<()>;

/// Do the refresh token authentication process.
async fn do_refresh_token(&self, refresh_token: &str) -> AuthResult<model::RefreshToken>;

fn supports(&self, t: &AuthStrategy) -> bool;
/// Check if the provider supports the given auth strategy.
fn support(&self, t: &AuthStrategy) -> bool;
}

trait RequestContextExt {
Expand Down
2 changes: 1 addition & 1 deletion crates/openai/src/auth/provide/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl PlatformAuthProvider {
}

impl AuthProvider for PlatformAuthProvider {
fn supports(&self, t: &AuthStrategy) -> bool {
fn support(&self, t: &AuthStrategy) -> bool {
t.eq(&AuthStrategy::Platform)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/openai/src/auth/provide/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ impl WebAuthProvider {
}

impl AuthProvider for WebAuthProvider {
fn supports(&self, t: &AuthStrategy) -> bool {
fn support(&self, t: &AuthStrategy) -> bool {
t.eq(&AuthStrategy::Web)
}

Expand Down
4 changes: 2 additions & 2 deletions crates/openai/src/serve/middleware/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use axum::{
response::Response,
};

use super::tokenbucket::{TokenBucket, TokenBucketLimitContext};
use super::tokenbucket::{TokenBucket, TokenBucketProvider};

pub(crate) async fn limit_middleware<B>(
State(limit): State<std::sync::Arc<TokenBucketLimitContext>>,
State(limit): State<std::sync::Arc<TokenBucketProvider>>,
ConnectInfo(socket_addr): ConnectInfo<std::net::SocketAddr>,
request: Request<B>,
next: Next<B>,
Expand Down
6 changes: 3 additions & 3 deletions crates/openai/src/serve/middleware/tokenbucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ fn ip_to_number(ip: IpAddr) -> u128 {
}
}

pub enum TokenBucketLimitContext {
pub enum TokenBucketProvider {
Mem(MemTokenBucket),
ReDB(RedisTokenBucket<'static>),
}

impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketLimitContext {
impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketProvider {
fn from(value: (Strategy, bool, u32, u32, u32)) -> Self {
let strategy = match value.0 {
Strategy::Mem => Self::Mem(MemTokenBucket::new(value.1, value.2, value.3, value.4)),
Expand All @@ -290,7 +290,7 @@ impl From<(Strategy, bool, u32, u32, u32)> for TokenBucketLimitContext {
}
}

impl TokenBucket for TokenBucketLimitContext {
impl TokenBucket for TokenBucketProvider {
async fn acquire(&self, ip: IpAddr) -> anyhow::Result<bool> {
let condition = match self {
Self::Mem(t) => t.acquire(ip).await,
Expand Down
4 changes: 2 additions & 2 deletions crates/openai/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::dns;
use crate::proxy::{InnerProxy, Proxy};
use crate::serve::error::ProxyError;
use crate::serve::error::ResponseError;
use crate::serve::middleware::tokenbucket::{Strategy, TokenBucketLimitContext};
use crate::serve::middleware::tokenbucket::{Strategy, TokenBucketProvider};
use crate::{info, warn, with_context};
use crate::{URL_CHATGPT_API, URL_PLATFORM_API};
use axum::body::Body;
Expand Down Expand Up @@ -158,7 +158,7 @@ impl Serve {

// init auth layer provider
let app_layer = {
let limit_context = TokenBucketLimitContext::from((
let limit_context = TokenBucketProvider::from((
Strategy::from_str(self.0.tb_strategy.as_str())?,
self.0.tb_enable,
self.0.tb_capacity,
Expand Down

0 comments on commit 2960a42

Please sign in to comment.