Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): Use an asynchronous DNS resolver #351

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ log = "0.4.20"
anyhow = "1.0.75"
thiserror = "1.0.48"
reqwest = { package = "reqwest-impersonate", version ="0.11.30", default-features = false, features = [
"boring-tls", "impersonate","json", "cookies", "stream", "multipart", "socks", "trust-dns"
"boring-tls", "impersonate","json", "cookies", "stream", "multipart", "socks"
] }
hyper = { package = "hyper_imp", version = "0.14.28", default-features = false, features = [
"client",
] }
trust-dns-resolver = { version = "0.23.2", default-features = false, features = ["system-config", "tokio-runtime"] }
tokio = { version = "1.32.0", features = ["fs", "sync", "signal", "rt-multi-thread"] }
serde_json = "1.0.107"
serde = {version = "1.0.188", features = ["derive"] }
Expand Down
12 changes: 12 additions & 0 deletions openai/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::auth::{self};
use crate::dns;
use crate::{
auth::AuthClient,
context, debug,
proxy::{self, Ipv6CidrExt},
};
use reqwest::{impersonate::Impersonate, Client};
use std::sync::{Arc, OnceLock};
use std::{
net::IpAddr,
sync::atomic::{AtomicUsize, Ordering},
time::Duration,
};
use url::Url;

static DNS_RESOLVER: OnceLock<Arc<dns::TrustDnsResolver>> = OnceLock::new();

/// Client type
#[derive(Clone)]
pub enum ClientAgent {
Expand Down Expand Up @@ -294,6 +298,7 @@ fn build_client(
.danger_accept_invalid_certs(true)
.connect_timeout(Duration::from_secs(config.connect_timeout))
.timeout(Duration::from_secs(config.timeout))
.dns_resolver(get_dns_resolver())
.build()
.expect("Failed to build API client");
client
Expand Down Expand Up @@ -347,6 +352,13 @@ fn get_next_index(len: usize, counter: &AtomicUsize) -> usize {
new
}

fn get_dns_resolver() -> Arc<dns::TrustDnsResolver> {
let dns = DNS_RESOLVER
.get_or_init(|| Arc::new(dns::TrustDnsResolver::default()))
.clone();
dns
}

const RANDOM_IMPERSONATE: [Impersonate; 7] = [
Impersonate::OkHttp3_9,
Impersonate::OkHttp3_11,
Expand Down
59 changes: 59 additions & 0 deletions openai/src/dns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate
use hyper::client::connect::dns::Name;
use reqwest::dns::{Addrs, Resolve, Resolving};
use tokio::sync::OnceCell;
use trust_dns_resolver::config::LookupIpStrategy;
pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::{lookup_ip::LookupIpIntoIter, system_conf, TokioAsyncResolver};

use std::io;
use std::net::SocketAddr;
use std::sync::Arc;

/// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait.
#[derive(Debug, Clone, Default)]
pub(crate) struct TrustDnsResolver {
/// Since we might not have been called in the context of a
/// Tokio Runtime in initialization, so we must delay the actual
/// construction of the resolver.
state: Arc<OnceCell<TokioAsyncResolver>>,
}

struct SocketAddrs {
iter: LookupIpIntoIter,
}

impl Resolve for TrustDnsResolver {
fn resolve(&self, name: Name) -> Resolving {
let resolver = self.clone();
Box::pin(async move {
let resolver = resolver.state.get_or_try_init(new_resolver).await?;
let lookup = resolver.lookup_ip(name.as_str()).await?;
let addrs: Addrs = Box::new(SocketAddrs {
iter: lookup.into_iter(),
});
Ok(addrs)
})
}
}

impl Iterator for SocketAddrs {
type Item = SocketAddr;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|ip_addr| SocketAddr::new(ip_addr, 0))
}
}

/// Create a new resolver with the default configuration,
/// which reads from `/etc/resolve.conf`.
async fn new_resolver() -> io::Result<TokioAsyncResolver> {
let (config, mut opts) = system_conf::read_system_conf().map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("error reading DNS system conf: {}", e),
)
})?;
opts.ip_strategy = LookupIpStrategy::Ipv6thenIpv4;
Ok(TokioAsyncResolver::tokio(config, opts))
}
2 changes: 2 additions & 0 deletions openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ pub mod auth;
pub mod chatgpt;
pub mod client;
pub mod context;
mod dns;
pub mod eventsource;
pub mod homedir;
pub mod log;
pub mod platform;
pub mod proxy;

#[cfg(feature = "serve")]
pub mod serve;
pub mod token;
Expand Down
7 changes: 5 additions & 2 deletions openai/src/serve/proxy/toapi/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ pub(super) async fn not_stream_handler(
if !finish.is_empty() {
finish_reason = Some(finish.to_owned())
}
let messages = convo.messages();
if let Some(message) = messages.first() {

// If message is not empty, set previous message
if let Some(message) = convo.messages().first() {
previous_message.clear();
previous_message.push_str(message);
}

drop(convo)
}
}
}
Expand Down
Loading