Skip to content

Commit

Permalink
proxy: Implement cancellation rate limiting
Browse files Browse the repository at this point in the history
- Implement cancellation rate limiting using LeakyBucketRateLimiter
  per ip subnets.
- Add ip_allowlist to the cancel closure and use it for ip allowlist
  check.
- Add optional peer_addr to CancelSession stream message

Fixes #16456
  • Loading branch information
awarus committed Nov 22, 2024
1 parent ceaa80f commit 0d69762
Show file tree
Hide file tree
Showing 14 changed files with 173 additions and 49 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ hyper-util = "0.1"
tokio-tungstenite = "0.21.0"
indexmap = "2"
indoc = "2"
ipnet = "2.9.0"
ipnet = "2.10.0"
itertools = "0.10"
itoa = "1.0.11"
jsonwebtoken = "9"
Expand Down
20 changes: 12 additions & 8 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tokio_postgres::config::SslMode;
use tracing::{info, info_span};

use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
Expand Down Expand Up @@ -74,10 +75,10 @@ impl ConsoleRedirectBackend {
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<ConsoleRedirectNodeInfo> {
) -> auth::Result<(ConsoleRedirectNodeInfo, Option<Vec<IpPattern>>)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(ConsoleRedirectNodeInfo)
.map(|(node_info, ip_allowlist)| (ConsoleRedirectNodeInfo(node_info), ip_allowlist))
}
}

Expand All @@ -102,7 +103,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
) -> auth::Result<(NodeInfo, Option<Vec<IpPattern>>)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);

// registering waiter can fail if we get unlucky with rng.
Expand Down Expand Up @@ -176,9 +177,12 @@ async fn authenticate(
config.password(password.as_ref());
}

Ok(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
})
Ok((
NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
},
db_info.allowed_ips,
))
}
18 changes: 1 addition & 17 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod local;

use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;

pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
Expand All @@ -30,7 +29,7 @@ use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
Expand Down Expand Up @@ -192,21 +191,6 @@ impl MaskedIp {
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;

impl RateBucketInfo {
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
}

impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
Expand Down
3 changes: 2 additions & 1 deletion proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ async fn main() -> anyhow::Result<()> {
)?))),
None => None,
};

let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
cancel_map.clone(),
redis_publisher,
Expand Down
87 changes: 84 additions & 3 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,23 @@ use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;

use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::error::ReportableError;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use std::net::IpAddr;

use ipnet::{IpNet, Ipv4Net, Ipv6Net};

pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;

type IpSubnetKey = IpNet;

/// Enables serving `CancelRequest`s.
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
Expand All @@ -29,14 +36,23 @@ pub struct CancellationHandler<P> {
/// This field used for the monitoring purposes.
/// Represents the source of the cancellation request.
from: CancellationSource,
// rate limiter of cancellation requests
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
}

#[derive(Debug, Error)]
pub(crate) enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),

#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),

#[error("rate limit exceeded")]
RateLimit,

#[error("IP is not allowed")]
IpNotAllowed,
}

impl ReportableError for CancelError {
Expand All @@ -47,6 +63,8 @@ impl ReportableError for CancelError {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
}
}
}
Expand Down Expand Up @@ -79,13 +97,36 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
cancellation_handler: self,
}
}

/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query
pub(crate) async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
peer_addr: &IpAddr,
check_allowed: bool,
) -> Result<(), CancelError> {
// TODO: check for unspecified address is only for backward compatibility, should be removed
if !peer_addr.is_unspecified() {
let subnet_key = match *peer_addr {
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
if !self.limiter.lock().unwrap().check(subnet_key, 1) {
tracing::debug!("Rate limit exceeded. Skipping cancellation message");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
});
return Err(CancelError::RateLimit);
}
}

// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
Expand All @@ -96,7 +137,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
source: self.from,
kind: crate::metrics::CancellationOutcome::NotFound,
});
match self.client.try_publish(key, session_id).await {

if session_id == Uuid::nil() {
// was already published, do not publish it again
return Ok(());
}

match self.client.try_publish(key, session_id, *peer_addr).await {
Ok(()) => {} // do nothing
Err(e) => {
return Err(CancelError::IO(std::io::Error::new(
Expand All @@ -107,6 +154,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
return Ok(());
};

if check_allowed
&& !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
{
return Err(CancelError::IpNotAllowed);
}

Metrics::get()
.proxy
.cancellation_requests_total
Expand Down Expand Up @@ -135,13 +189,29 @@ impl CancellationHandler<()> {
map,
client: (),
from,
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
64,
),
)),
}
}
}

impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
Self { map, client, from }
Self {
map,
client,
from,
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
64,
),
)),
}
}
}

Expand All @@ -152,13 +222,19 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
}

impl CancelClosure {
pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
) -> Self {
Self {
socket_addr,
cancel_token,
ip_allowlist,
}
}
/// Cancels the query running on user's compute node.
Expand All @@ -168,6 +244,9 @@ impl CancelClosure {
info!("query was cancelled");
Ok(())
}
pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
self.ip_allowlist = ip_allowlist;
}
}

/// Helper for registering query cancellation tokens.
Expand Down Expand Up @@ -229,6 +308,8 @@ mod tests {
cancel_key: 0,
},
Uuid::new_v4(),
&("127.0.0.1".parse().unwrap()),
true,
)
.await
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ impl ConnCfg {

// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);

let connection = PostgresConnection {
stream,
Expand Down
13 changes: 10 additions & 3 deletions proxy/src/console_redirect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,21 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let request_gauge = metrics.connection_requests.guard(proto);

let tls = config.tls_config.as_ref();

let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);

let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
.cancel_session(
cancel_key_data,
ctx.session_id(),
&ctx.peer_addr(),
config.authentication_config.ip_allowlist_check_enabled,
)
.await
.map(|()| None)?)
}
Expand All @@ -174,7 +179,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

ctx.set_db_options(params.clone());

let user_info = match backend
let (user_info, ip_allowlist) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Expand All @@ -198,6 +203,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.or_else(|e| stream.throw_error(e))
.await?;

node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;

Expand Down
1 change: 1 addition & 0 deletions proxy/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ pub enum CancellationSource {
pub enum CancellationOutcome {
NotFound,
Found,
RateLimitExceeded,
}

#[derive(LabelGroup)]
Expand Down
8 changes: 7 additions & 1 deletion proxy/src/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,18 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);

let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
.cancel_session(
cancel_key_data,
ctx.session_id(),
&ctx.peer_addr(),
config.authentication_config.ip_allowlist_check_enabled,
)
.await
.map(|()| None)?)
}
Expand Down
Loading

0 comments on commit 0d69762

Please sign in to comment.