proxy: reduce some per-task memory usage #8095

Merged 2 commits on Jun 19, 2024
42 changes: 22 additions & 20 deletions proxy/src/proxy.rs
@@ -91,7 +91,7 @@ pub async fn task_main(
         let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();

         connections.spawn(async move {
-            let (socket, peer_addr) = match read_proxy_protocol(socket).await{
+            let (socket, peer_addr) = match read_proxy_protocol(socket).await {
                 Ok((socket, Some(addr))) => (socket, addr.ip()),
                 Err(e) => {
                     error!("per-client task finished with an error: {e:#}");
@@ -101,36 +101,38 @@
                     error!("missing required client IP");
                     return;
                 }
-                Ok((socket, None)) => (socket, peer_addr.ip())
+                Ok((socket, None)) => (socket, peer_addr.ip()),
             };

             match socket.inner.set_nodelay(true) {
-                Ok(()) => {},
+                Ok(()) => {}
                 Err(e) => {
                     error!("per-client task finished with an error: failed to set socket option: {e:#}");
                     return;
-                },
+                }
             };

             let mut ctx = RequestMonitoring::new(
                 session_id,
                 peer_addr,
                 crate::metrics::Protocol::Tcp,
                 &config.region,
             );
             let span = ctx.span.clone();

-            let res = handle_client(
-                config,
-                &mut ctx,
-                cancellation_handler,
-                socket,
-                ClientMode::Tcp,
-                endpoint_rate_limiter2,
-                conn_gauge,
-            )
-            .instrument(span.clone())
-            .await;
+            let startup = Box::pin(
+                handle_client(
+                    config,
+                    &mut ctx,
+                    cancellation_handler,
+                    socket,
+                    ClientMode::Tcp,
+                    endpoint_rate_limiter2,
+                    conn_gauge,
+                )
+                .instrument(span.clone()),
+            );
+            let res = startup.await;

             match res {
                 Err(e) => {
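Note: `tokio::spawn` places the entire future it is given into a single per-task allocation, so every byte of `handle_client`'s state machine is paid up front by every client connection. `Box::pin` moves that state machine behind one pointer, leaving the spawned task itself small. A minimal standalone sketch of the effect (not code from this PR; `big_future` stands in for `handle_client`):

    use std::mem::size_of_val;

    async fn big_future() {
        // State held across an await point is stored in the future itself.
        let buf = [0u8; 16 * 1024];
        tokio::task::yield_now().await;
        let _ = buf[0];
    }

    #[tokio::main]
    async fn main() {
        let inline = big_future();
        println!("inline future: {} bytes", size_of_val(&inline)); // roughly 16 KiB
        let boxed = Box::pin(big_future());
        println!("boxed handle:  {} bytes", size_of_val(&boxed)); // one pointer
        inline.await;
        boxed.await;
    }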
2 changes: 1 addition & 1 deletion proxy/src/proxy/copy_bidirectional.rs
@@ -98,7 +98,7 @@ pub(super) struct CopyBuffer {
     amt: u64,
     buf: Box<[u8]>,
 }
-const DEFAULT_BUF_SIZE: usize = 8 * 1024;
+const DEFAULT_BUF_SIZE: usize = 1024;

 impl CopyBuffer {
     pub(super) fn new() -> Self {
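Note: if, as in tokio's `copy_bidirectional` that this module adapts, one `CopyBuffer` is allocated per direction, shrinking the constant from 8 KiB to 1 KiB saves about 14 KiB per proxied connection. Back-of-envelope arithmetic:

    const OLD_BUF: usize = 8 * 1024;
    const NEW_BUF: usize = 1024;

    fn main() {
        let saved = 2 * (OLD_BUF - NEW_BUF); // two directions per connection
        println!("saved per connection: {saved} bytes"); // 14336
        // Across 100k concurrent connections that is roughly 1.3 GiB.
        println!("saved per 100k conns: {} MiB", saved * 100_000 / (1024 * 1024));
    }

The trade-off is more syscalls per byte for bulk transfers, which is presumably acceptable for mostly-idle database connections.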
120 changes: 73 additions & 47 deletions proxy/src/serverless.rs
@@ -27,14 +27,14 @@ use rand::SeedableRng;
 pub use reqwest_middleware::{ClientWithMiddleware, Error};
 pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use tokio::time::timeout;
-use tokio_rustls::TlsAcceptor;
+use tokio_rustls::{server::TlsStream, TlsAcceptor};
 use tokio_util::task::TaskTracker;

 use crate::cancellation::CancellationHandlerMain;
 use crate::config::ProxyConfig;
 use crate::context::RequestMonitoring;
 use crate::metrics::Metrics;
-use crate::protocol2::read_proxy_protocol;
+use crate::protocol2::{read_proxy_protocol, ChainRW};
 use crate::proxy::run_until_cancelled;
 use crate::rate_limiter::EndpointRateLimiter;
 use crate::serverless::backend::PoolingBackend;
@@ -102,8 +102,6 @@ pub async fn task_main(
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     connections.close(); // allows `connections.wait to complete`

-    let server = Builder::new(TokioExecutor::new());
-
     while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
         let (conn, peer_addr) = res.context("could not accept TCP stream")?;
         if let Err(e) = conn.set_nodelay(true) {
@@ -127,65 +125,73 @@
         }

         let conn_token = cancellation_token.child_token();
-        let conn = connection_handler(
-            config,
-            backend.clone(),
-            connections.clone(),
-            cancellation_handler.clone(),
-            endpoint_rate_limiter.clone(),
-            conn_token.clone(),
-            server.clone(),
-            tls_acceptor.clone(),
-            conn,
-            peer_addr,
-        )
-        .instrument(http_conn_span);
-
-        connections.spawn(async move {
-            let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token);
-            conn.await
-        });
+        let tls_acceptor = tls_acceptor.clone();
+        let backend = backend.clone();
+        let connections2 = connections.clone();
+        let cancellation_handler = cancellation_handler.clone();
+        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
+        connections.spawn(
+            async move {
+                let conn_token2 = conn_token.clone();
+                let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
+
+                let session_id = uuid::Uuid::new_v4();
+
+                let _gauge = Metrics::get()
+                    .proxy
+                    .client_connections
+                    .guard(crate::metrics::Protocol::Http);
+
+                let startup_result = Box::pin(connection_startup(
+                    config,
+                    tls_acceptor,
+                    session_id,
+                    conn,
+                    peer_addr,
+                ))
+                .await;
+                let Some((conn, peer_addr)) = startup_result else {
+                    return;
+                };
+
+                Box::pin(connection_handler(
+                    config,
+                    backend,
+                    connections2,
+                    cancellation_handler,
+                    endpoint_rate_limiter,
+                    conn_token,
+                    conn,
+                    peer_addr,
+                    session_id,
+                ))
+                .await;
+            }
+            .instrument(http_conn_span),
+        );
     }

     connections.wait().await;

     Ok(())
 }

-/// Handles the TCP lifecycle.
-///
+/// Handles the TCP startup lifecycle.
 /// 1. Parses PROXY protocol V2
 /// 2. Handles TLS handshake
-/// 3. Handles HTTP connection
-///    1. With graceful shutdowns
-///    2. With graceful request cancellation with connection failure
-///    3. With websocket upgrade support.
-#[allow(clippy::too_many_arguments)]
-async fn connection_handler(
-    config: &'static ProxyConfig,
-    backend: Arc<PoolingBackend>,
-    connections: TaskTracker,
-    cancellation_handler: Arc<CancellationHandlerMain>,
-    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    cancellation_token: CancellationToken,
-    server: Builder<TokioExecutor>,
+async fn connection_startup(
+    config: &ProxyConfig,
     tls_acceptor: TlsAcceptor,
+    session_id: uuid::Uuid,
     conn: TcpStream,
     peer_addr: SocketAddr,
-) {
-    let session_id = uuid::Uuid::new_v4();
-
-    let _gauge = Metrics::get()
-        .proxy
-        .client_connections
-        .guard(crate::metrics::Protocol::Http);
-
+) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
     // handle PROXY protocol
     let (conn, peer) = match read_proxy_protocol(conn).await {
         Ok(c) => c,
         Err(e) => {
             tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
-            return;
+            return None;
         }
     };

@@ -208,24 +214,44 @@ async fn connection_handler(
                 Metrics::get().proxy.tls_handshake_failures.inc();
             }
             warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
-            return;
+            return None;
         }
         // The handshake timed out
         Err(e) => {
             if !has_private_peer_addr {
                 Metrics::get().proxy.tls_handshake_failures.inc();
             }
             warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
-            return;
+            return None;
         }
     };

+    Some((conn, peer_addr))
+}
+
+/// Handles HTTP connection
+/// 1. With graceful shutdowns
+/// 2. With graceful request cancellation with connection failure
+/// 3. With websocket upgrade support.
+#[allow(clippy::too_many_arguments)]
+async fn connection_handler(
+    config: &'static ProxyConfig,
+    backend: Arc<PoolingBackend>,
+    connections: TaskTracker,
+    cancellation_handler: Arc<CancellationHandlerMain>,
+    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
+    cancellation_token: CancellationToken,
+    conn: TlsStream<ChainRW<TcpStream>>,
+    peer_addr: IpAddr,
+    session_id: uuid::Uuid,
+) {
     let session_id = AtomicTake::new(session_id);

     // Cancel all current inflight HTTP requests if the HTTP connection is closed.
     let http_cancellation_token = CancellationToken::new();
     let _cancel_connection = http_cancellation_token.clone().drop_guard();

+    let server = Builder::new(TokioExecutor::new());
     let conn = server.serve_connection_with_upgrades(
         hyper_util::rt::TokioIo::new(conn),
         hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
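Note: the refactor splits the old monolithic `connection_handler` into a startup phase (PROXY protocol parse + TLS handshake) and a serving phase, and boxes each phase inside the spawned task. The task's own future then holds little more than two pointers, and the startup allocation is freed as soon as the handshake completes. A condensed sketch of the pattern (hypothetical `startup`/`handle` stand-ins, not the real signatures):

    use std::net::SocketAddr;
    use tokio::net::TcpStream;

    async fn startup(conn: TcpStream) -> Option<TcpStream> {
        // handshake state lives only inside this future
        Some(conn)
    }

    async fn handle(_conn: TcpStream) {
        // request-serving state lives only inside this future
    }

    fn spawn_connection(conn: TcpStream, _peer: SocketAddr) {
        tokio::spawn(async move {
            // Each phase is boxed separately; the startup box is dropped
            // the moment its `.await` returns.
            let Some(conn) = Box::pin(startup(conn)).await else {
                return;
            };
            Box::pin(handle(conn)).await;
        });
    }

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        loop {
            let (conn, peer) = listener.accept().await?;
            spawn_connection(conn, peer);
        }
    }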
2 changes: 1 addition & 1 deletion proxy/src/serverless/backend.rs
@@ -104,7 +104,7 @@ impl PoolingBackend {
     ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
         let maybe_client = if !force_new {
             info!("pool: looking for an existing connection");
-            self.pool.get(ctx, &conn_info).await?
+            self.pool.get(ctx, &conn_info)?
         } else {
             info!("pool: pool is disabled");
             None
2 changes: 1 addition & 1 deletion proxy/src/serverless/conn_pool.rs
@@ -375,7 +375,7 @@
         }
     }

-    pub async fn get(
+    pub fn get(
         self: &Arc<Self>,
         ctx: &mut RequestMonitoring,
         conn_info: &ConnInfo,
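Note: dropping `async` from `GlobalConnPool::get` is a size win in itself, not just cleanup. An `async fn`'s returned future gets embedded in the caller's state machine for the duration of the `.await`, whereas a plain function call only uses the poll-time stack. A standalone illustration (hypothetical functions, not the pool API):

    use std::mem::size_of_val;

    async fn embedded() {
        let scratch = [0u8; 256];
        tokio::task::yield_now().await; // `scratch` is live across this await
        let _ = scratch[0];
    }

    async fn caller() {
        // `embedded`'s whole future is part of `caller`'s layout.
        embedded().await;
    }

    fn main() {
        println!("caller future: {} bytes", size_of_val(&caller())); // >= 256
    }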
46 changes: 25 additions & 21 deletions proxy/src/serverless/sql_over_http.rs
@@ -533,27 +533,31 @@ async fn handle_inner(
         return Err(SqlOverHttpError::RequestTooLarge);
     }

-    let fetch_and_process_request = async {
-        let body = request.into_body().collect().await?.to_bytes();
-        info!(length = body.len(), "request payload read");
-        let payload: Payload = serde_json::from_slice(&body)?;
-        Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
-    }
-    .map_err(SqlOverHttpError::from);
-
-    let authenticate_and_connect = async {
-        let keys = backend
-            .authenticate(ctx, &config.authentication_config, &conn_info)
-            .await?;
-        let client = backend
-            .connect_to_compute(ctx, conn_info, keys, !allow_pool)
-            .await?;
-        // not strictly necessary to mark success here,
-        // but it's just insurance for if we forget it somewhere else
-        ctx.latency_timer.success();
-        Ok::<_, HttpConnError>(client)
-    }
-    .map_err(SqlOverHttpError::from);
+    let fetch_and_process_request = Box::pin(
+        async {
+            let body = request.into_body().collect().await?.to_bytes();
+            info!(length = body.len(), "request payload read");
+            let payload: Payload = serde_json::from_slice(&body)?;
+            Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
+        }
+        .map_err(SqlOverHttpError::from),
+    );
+
+    let authenticate_and_connect = Box::pin(
+        async {
+            let keys = backend
+                .authenticate(ctx, &config.authentication_config, &conn_info)
+                .await?;
+            let client = backend
+                .connect_to_compute(ctx, conn_info, keys, !allow_pool)
+                .await?;
+            // not strictly necessary to mark success here,
+            // but it's just insurance for if we forget it somewhere else
+            ctx.latency_timer.success();
+            Ok::<_, HttpConnError>(client)
+        }
+        .map_err(SqlOverHttpError::from),
+    );

     let (payload, mut client) = match run_until_cancelled(
         // Run both operations in parallel
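Note: the two arms here run concurrently, so whatever combinator joins them must hold both state machines at once; boxing each arm up front means the combinator stores two pointers instead. A reduced sketch using `futures::future::try_join` in place of this file's combinator (all names are placeholders):

    use futures::future::try_join;

    async fn read_payload() -> Result<Vec<u8>, String> {
        Ok(Vec::new()) // stand-in for collecting the HTTP body
    }

    async fn connect() -> Result<u32, String> {
        Ok(1) // stand-in for authenticate + connect_to_compute
    }

    async fn handle() -> Result<(), String> {
        // Box each arm before joining so the join future stays pointer-sized
        // per arm rather than inlining both full state machines.
        let payload_fut = Box::pin(read_payload());
        let connect_fut = Box::pin(connect());
        let (_payload, _client) = try_join(payload_fut, connect_fut).await?;
        Ok(())
    }

    #[tokio::main]
    async fn main() {
        handle().await.unwrap();
    }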
4 changes: 2 additions & 2 deletions proxy/src/serverless/websocket.rs
@@ -141,15 +141,15 @@ pub async fn serve_websocket(
         .client_connections
         .guard(crate::metrics::Protocol::Ws);

-    let res = handle_client(
+    let res = Box::pin(handle_client(
         config,
         &mut ctx,
         cancellation_handler,
         WebSocketRw::new(websocket),
         ClientMode::Websockets { hostname },
         endpoint_rate_limiter,
         conn_gauge,
-    )
+    ))
     .await;

     match res {