Skip to content

Commit

Permalink
Merge pull request #60 from Dstack-TEE/tproxy-multi-connect
Browse files Browse the repository at this point in the history
tproxy: Connect to multiple hosts
  • Loading branch information
kvinwang authored Dec 16, 2024
2 parents a2915e7 + 3d94f0d commit 194e0dc
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 60 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ tailf = "0.1.2"
time = "0.3.37"
uuid = { version = "1.11.0", features = ["v4"] }
which = "7.0.0"
smallvec = "1.13.2"
2 changes: 2 additions & 0 deletions tproxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ tproxy-rpc.workspace = true
certbot.workspace = true
bytes.workspace = true
safe-write.workspace = true
smallvec.workspace = true
futures.workspace = true

[target.'cfg(unix)'.dependencies]
nix = { workspace = true, features = ["resource"] }
Expand Down
8 changes: 6 additions & 2 deletions tproxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct ProxyConfig {
pub tappd_port: u16,
pub timeouts: Timeouts,
pub buffer_size: usize,
pub connect_top_n: usize,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -39,6 +40,11 @@ pub struct Timeouts {
pub connect: Duration,
#[serde(with = "serde_duration")]
pub handshake: Duration,
#[serde(with = "serde_duration")]
pub total: Duration,

#[serde(with = "serde_duration")]
pub cache_top_n: Duration,

pub data_timeout_enabled: bool,
#[serde(with = "serde_duration")]
Expand All @@ -47,8 +53,6 @@ pub struct Timeouts {
pub write: Duration,
#[serde(with = "serde_duration")]
pub shutdown: Duration,
#[serde(with = "serde_duration")]
pub total: Duration,
}

#[derive(Debug, Clone, Deserialize)]
Expand Down
2 changes: 1 addition & 1 deletion tproxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn main() -> Result<()> {

let proxy_config = config.proxy.clone();
let pccs_url = config.pccs_url.clone();
let state = main_service::AppState::new(config)?;
let state = main_service::Proxy::new(config)?;
state.lock().reconfigure()?;
proxy::start(proxy_config, state.clone());

Expand Down
92 changes: 72 additions & 20 deletions tproxy/src/main_service.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{
collections::{BTreeMap, BTreeSet},
net::Ipv4Addr,
process::Command,
process::{Command, Stdio},
sync::{Arc, Mutex, MutexGuard, Weak},
time::{Duration, SystemTime, UNIX_EPOCH},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};

use anyhow::{bail, Context, Result};
Expand All @@ -14,38 +14,42 @@ use rand::seq::IteratorRandom;
use rinja::Template as _;
use safe_write::safe_write;
use serde::{Deserialize, Serialize};
use smallvec::{smallvec, SmallVec};
use tproxy_rpc::{
tproxy_server::{TproxyRpc, TproxyServer},
AcmeInfoResponse, GetInfoRequest, GetInfoResponse, HostInfo as PbHostInfo, ListResponse,
RegisterCvmRequest, RegisterCvmResponse, TappdConfig, WireGuardConfig,
};
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};

use crate::{
config::Config,
models::{InstanceInfo, WgConf},
proxy::AddressGroup,
};

#[derive(Clone)]
pub struct AppState {
pub struct Proxy {
pub(crate) config: Arc<Config>,
inner: Arc<Mutex<AppStateInner>>,
inner: Arc<Mutex<ProxyState>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct State {
struct ProxyStateMut {
apps: BTreeMap<String, BTreeSet<String>>,
instances: BTreeMap<String, InstanceInfo>,
allocated_addresses: BTreeSet<Ipv4Addr>,
#[serde(skip)]
top_n: BTreeMap<String, (AddressGroup, Instant)>,
}

pub(crate) struct AppStateInner {
pub(crate) struct ProxyState {
config: Arc<Config>,
state: State,
state: ProxyStateMut,
}

impl AppState {
pub(crate) fn lock(&self) -> MutexGuard<AppStateInner> {
impl Proxy {
pub(crate) fn lock(&self) -> MutexGuard<ProxyState> {
self.inner.lock().expect("Failed to lock AppState")
}

Expand All @@ -56,13 +60,14 @@ impl AppState {
let state_str = fs::read_to_string(state_path).context("Failed to read state")?;
serde_json::from_str(&state_str).context("Failed to load state")?
} else {
State {
ProxyStateMut {
apps: BTreeMap::new(),
top_n: BTreeMap::new(),
instances: BTreeMap::new(),
allocated_addresses: BTreeSet::new(),
}
};
let inner = Arc::new(Mutex::new(AppStateInner {
let inner = Arc::new(Mutex::new(ProxyState {
config: config.clone(),
state,
}));
Expand All @@ -71,7 +76,7 @@ impl AppState {
}
}

fn start_recycle_thread(state: Weak<Mutex<AppStateInner>>, config: Arc<Config>) {
fn start_recycle_thread(state: Weak<Mutex<ProxyState>>, config: Arc<Config>) {
if !config.recycle.enabled {
info!("recycle is disabled");
return;
Expand All @@ -87,7 +92,7 @@ fn start_recycle_thread(state: Weak<Mutex<AppStateInner>>, config: Arc<Config>)
});
}

impl AppStateInner {
impl ProxyState {
fn alloc_ip(&mut self) -> Option<Ipv4Addr> {
for ip in self.config.wg.client_ip_range.hosts() {
if ip == self.config.wg.ip {
Expand Down Expand Up @@ -166,10 +171,49 @@ impl AppStateInner {
Ok(())
}

pub(crate) fn select_a_host(&self, id: &str) -> Option<InstanceInfo> {
pub(crate) fn select_top_n_hosts(&mut self, id: &str) -> Result<AddressGroup> {
let n = self.config.proxy.connect_top_n;
if let Some(instance) = self.state.instances.get(id) {
return Ok(smallvec![instance.ip]);
};
let app_instances = self.state.apps.get(id).context("app not found")?;
if n == 0 {
// fallback to random selection
return Ok(self.random_select_a_host(id).unwrap_or_default());
}
let (top_n, insert_time) = self
.state
.top_n
.entry(id.to_string())
.or_insert((SmallVec::new(), Instant::now()));
if !top_n.is_empty() && insert_time.elapsed() < self.config.proxy.timeouts.cache_top_n {
return Ok(top_n.clone());
}

let handshakes = self.latest_handshakes(None);
let mut instances = match handshakes {
Err(err) => {
warn!("Failed to get handshakes, fallback to random selection: {err}");
return Ok(self.random_select_a_host(id).unwrap_or_default());
}
Ok(handshakes) => app_instances
.iter()
.filter_map(|instance_id| {
let instance = self.state.instances.get(instance_id)?;
let (_, elapsed) = handshakes.get(&instance.public_key)?;
Some((instance.ip, *elapsed))
})
.collect::<SmallVec<[_; 4]>>(),
};
instances.sort_by(|a, b| a.1.cmp(&b.1));
instances.truncate(n);
Ok(instances.into_iter().map(|(ip, _)| ip).collect())
}

fn random_select_a_host(&self, id: &str) -> Option<AddressGroup> {
// Direct instance lookup first
if let Some(info) = self.state.instances.get(id).cloned() {
return Some(info);
return Some(smallvec![info.ip]);
}

let app_instances = self.state.apps.get(id)?;
Expand All @@ -191,9 +235,15 @@ impl AppStateInner {
});

let selected = healthy_instances.choose(&mut rand::thread_rng())?;
self.state.instances.get(selected).cloned()
self.state
.instances
.get(selected)
.map(|info| smallvec![info.ip])
}

/// Get latest handshakes
///
/// Return a map of public key to (timestamp, elapsed)
fn latest_handshakes(
&self,
stale_timeout: Option<Duration>,
Expand All @@ -211,6 +261,8 @@ impl AppStateInner {
.arg("show")
.arg(&self.config.wg.interface)
.arg("latest-handshakes")
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.context("failed to execute wg show command")?;

Expand Down Expand Up @@ -304,7 +356,7 @@ impl AppStateInner {

pub struct RpcHandler {
attestation: Option<Attestation>,
state: AppState,
state: Proxy,
}

impl TproxyRpc for RpcHandler {
Expand Down Expand Up @@ -413,14 +465,14 @@ impl TproxyRpc for RpcHandler {
}
}

impl RpcCall<AppState> for RpcHandler {
impl RpcCall<Proxy> for RpcHandler {
type PrpcService = TproxyServer<Self>;

fn into_prpc_service(self) -> Self::PrpcService {
TproxyServer::new(self)
}

fn construct(state: &AppState, attestation: Option<Attestation>) -> Result<Self>
fn construct(state: &Proxy, attestation: Option<Attestation>) -> Result<Self>
where
Self: Sized,
{
Expand Down
4 changes: 2 additions & 2 deletions tproxy/src/main_service/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::*;
use crate::config::{load_config_figment, Config};

fn create_test_state() -> AppState {
fn create_test_state() -> Proxy {
let figment = load_config_figment(None);
let config = figment.focus("core").extract::<Config>().unwrap();
AppState::new(config).expect("failed to create app state")
Proxy::new(config).expect("failed to create app state")
}

#[test]
Expand Down
17 changes: 12 additions & 5 deletions tproxy/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{net::Ipv4Addr, sync::Arc};

use anyhow::{bail, Context, Result};
use sni::extract_sni;
Expand All @@ -10,7 +10,9 @@ use tokio::{
};
use tracing::{debug, error, info};

use crate::{config::ProxyConfig, main_service::AppState};
use crate::{config::ProxyConfig, main_service::Proxy};

pub(crate) type AddressGroup = smallvec::SmallVec<[Ipv4Addr; 4]>;

mod io_bridge;
mod sni;
Expand Down Expand Up @@ -89,7 +91,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {

async fn handle_connection(
mut inbound: TcpStream,
state: AppState,
state: Proxy,
dotted_base_domain: &str,
tls_terminate_proxy: Arc<TlsTerminateProxy>,
) -> Result<()> {
Expand Down Expand Up @@ -126,7 +128,7 @@ async fn handle_connection(
}
}

pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> {
pub async fn run(config: &ProxyConfig, app_state: Proxy) -> Result<()> {
let dotted_base_domain = {
let base_domain = config.base_domain.as_str();
let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain);
Expand Down Expand Up @@ -187,7 +189,7 @@ pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> {
}
}

pub fn start(config: ProxyConfig, app_state: AppState) {
pub fn start(config: ProxyConfig, app_state: Proxy) {
tokio::spawn(async move {
if let Err(err) = run(&config, app_state).await {
error!(
Expand All @@ -197,3 +199,8 @@ pub fn start(config: ProxyConfig, app_state: AppState) {
}
});
}

// async fn connect_to_app(state: &AppState, app_id: &str, port: u16) -> Result<TcpStream> {
// let host = state.lock().select_a_host(app_id).context(format!("tapp {app_id} not found"))?;
// TcpStream::connect((host.ip, port))
// }
Loading

0 comments on commit 194e0dc

Please sign in to comment.