proxy: Implement cancellation rate limiting
- Implement cancellation rate limiting using GlobalRateLimiter and
perform IP allowlist checks.
- Add ip_allowlist to the cancel closure.

Fixes #16456
awarus committed Nov 19, 2024
1 parent ceaa80f commit 1ef092d
Showing 11 changed files with 168 additions and 40 deletions.
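
For orientation before the per-file diffs: the new cancellation path first rate-limits by the caller's masked IP subnet and then, when the check is enabled, verifies the peer against the connection's IP allowlist. The sketch below is a simplified, self-contained illustration of that flow; RateLimiter and subnet_of here are stand-ins, not the proxy's actual LeakyBucketRateLimiter or check_peer_addr_is_in_list.

use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};

/// Stand-in for the proxy's leaky-bucket limiter: admit at most `limit`
/// requests per subnet (no time-based draining in this toy version).
struct RateLimiter {
    limit: u32,
    counts: HashMap<IpAddr, u32>,
}

impl RateLimiter {
    fn check(&mut self, subnet: IpAddr) -> bool {
        let n = self.counts.entry(subnet).or_insert(0);
        *n += 1;
        *n <= self.limit
    }
}

/// Mask an IPv4 peer address to its /24 subnet (IPv6 handling omitted here).
fn subnet_of(ip: IpAddr) -> IpAddr {
    match ip {
        IpAddr::V4(v4) => IpAddr::V4(Ipv4Addr::from(u32::from(v4) & 0xFFFF_FF00)),
        other => other,
    }
}

/// Simplified cancellation gate: rate limit first, then allowlist.
fn may_cancel(
    limiter: &mut RateLimiter,
    peer: IpAddr,
    check_allowed: bool,
    allowlist: &[IpAddr],
) -> Result<(), &'static str> {
    if !limiter.check(subnet_of(peer)) {
        return Err("rate limit exceeded");
    }
    if check_allowed && !allowlist.contains(&peer) {
        return Err("IP is not allowed");
    }
    Ok(())
}

fn main() {
    let mut limiter = RateLimiter { limit: 2, counts: HashMap::new() };
    let peer: IpAddr = "192.168.1.42".parse().unwrap();
    let allowlist = vec![peer];
    assert!(may_cancel(&mut limiter, peer, true, &allowlist).is_ok());
    assert!(may_cancel(&mut limiter, peer, true, &allowlist).is_ok());
    // A third request from the same /24 subnet trips the (toy) limiter.
    assert!(may_cancel(&mut limiter, peer, true, &allowlist).is_err());
}
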
20 changes: 12 additions & 8 deletions proxy/src/auth/backend/console_redirect.rs
@@ -6,6 +6,7 @@ use tokio_postgres::config::SslMode;
use tracing::{info, info_span};

use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
@@ -74,10 +75,10 @@ impl ConsoleRedirectBackend {
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<ConsoleRedirectNodeInfo> {
) -> auth::Result<(ConsoleRedirectNodeInfo, Option<Vec<IpPattern>>)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(ConsoleRedirectNodeInfo)
.map(|(node_info, ip_allowlist)| (ConsoleRedirectNodeInfo(node_info), ip_allowlist))
}
}

@@ -102,7 +103,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
) -> auth::Result<(NodeInfo, Option<Vec<IpPattern>>)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);

// registering waiter can fail if we get unlucky with rng.
@@ -176,9 +177,12 @@ async fn authenticate(
config.password(password.as_ref());
}

Ok(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
})
Ok((
NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
},
db_info.allowed_ips,
))
}
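
With this change the console-redirect backend returns the IP allowlist alongside the node info instead of node info alone. A toy, self-contained sketch of the resulting tuple-return pattern (placeholder types, not the proxy's NodeInfo/IpPattern):

fn authenticate_stub() -> Result<(String, Option<Vec<String>>), ()> {
    // Pretend authentication succeeded and the control plane sent no allowlist.
    Ok(("node-info".to_string(), None))
}

fn main() {
    let (node_info, ip_allowlist) = authenticate_stub().unwrap();
    // `None` means no allowlist was configured; the caller treats it as empty,
    // mirroring the `unwrap_or_default()` used in console_redirect_proxy.rs below.
    let allowed_ips = ip_allowlist.unwrap_or_default();
    println!("{node_info}: {} allowlist entries", allowed_ips.len());
}
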
18 changes: 1 addition & 17 deletions proxy/src/auth/backend/mod.rs
@@ -6,7 +6,6 @@ pub mod local;

use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;

pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
@@ -30,7 +29,7 @@ use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
@@ -192,21 +191,6 @@ impl MaskedIp {
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;

impl RateBucketInfo {
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
}

impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
3 changes: 2 additions & 1 deletion proxy/src/bin/proxy.rs
@@ -427,8 +427,9 @@ async fn main() -> anyhow::Result<()> {
)?))),
None => None,
};

let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
cancel_map.clone(),
redis_publisher,
114 changes: 111 additions & 3 deletions proxy/src/cancellation.rs
@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;

use dashmap::DashMap;
@@ -10,16 +10,37 @@ use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;

use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::error::ReportableError;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use std::net::IpAddr;

pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;

type IpSubnetKey = IpAddr;

const IPV4_MASK: u32 = 0xFFFF_FF00; // /24 mask
const IPV6_MASK: u128 = 0xFFFF_FFFF_FFFF_FFFF_0000_0000_0000_0000; // /64 mask

fn normalize_ip(ip: &IpAddr) -> IpSubnetKey {
match ip {
IpAddr::V4(v4) => {
let subnet_u32 = u32::from(*v4) & IPV4_MASK;
IpAddr::V4(Ipv4Addr::from(subnet_u32))
}
IpAddr::V6(v6) => {
let subnet_u128 = u128::from(*v6) & IPV6_MASK;
IpAddr::V6(Ipv6Addr::from(subnet_u128))
}
}
}

/// Enables serving `CancelRequest`s.
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
@@ -29,14 +50,23 @@ pub struct CancellationHandler<P> {
/// This field is used for monitoring purposes.
/// Represents the source of the cancellation request.
from: CancellationSource,
// Rate limiter for cancellation requests
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
}

#[derive(Debug, Error)]
pub(crate) enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),

#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),

#[error("rate limit exceeded")]
RateLimit,

#[error("IP is not allowed")]
IpNotAllowed,
}

impl ReportableError for CancelError {
@@ -47,6 +77,8 @@ impl ReportableError for CancelError {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed => crate::error::ErrorKind::User,
}
}
}
@@ -79,13 +111,32 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
cancellation_handler: self,
}
}

/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query
pub(crate) async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
peer_addr: &IpAddr,
check_allowed: bool,
) -> Result<(), CancelError> {
if !peer_addr.is_unspecified() {
let subnet_key = normalize_ip(peer_addr);
if !self.limiter.lock().unwrap().check(subnet_key, 1) {
tracing::debug!("Rate limit exceeded. Skipping cancellation message");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
});
return Err(CancelError::RateLimit);
}
}

// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
@@ -107,6 +158,13 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
return Ok(());
};

if check_allowed
&& !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
{
return Err(CancelError::IpNotAllowed);
}

Metrics::get()
.proxy
.cancellation_requests_total
@@ -135,13 +193,29 @@ impl CancellationHandler<()> {
map,
client: (),
from,
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
64,
),
)),
}
}
}

impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
Self { map, client, from }
Self {
map,
client,
from,
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
64,
),
)),
}
}
}

@@ -152,13 +226,19 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
}

impl CancelClosure {
pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
) -> Self {
Self {
socket_addr,
cancel_token,
ip_allowlist,
}
}
/// Cancels the query running on user's compute node.
@@ -168,6 +248,9 @@ impl CancelClosure {
info!("query was cancelled");
Ok(())
}
pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
self.ip_allowlist = ip_allowlist;
}
}

/// Helper for registering query cancellation tokens.
@@ -201,6 +284,7 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
mod tests {
use super::*;
use crate::rate_limiter::RateBucketInfo;

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
@@ -229,8 +313,32 @@ mod tests {
cancel_key: 0,
},
Uuid::new_v4(),
&("127.0.0.1".parse().unwrap()),
true,
)
.await
.unwrap();
}

#[test]
fn test_normalize_ipv4() {
let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 42));
let expected = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 0));
assert_eq!(normalize_ip(&ip), expected);

let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 128));
let expected = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0));
assert_eq!(normalize_ip(&ip), expected);
}

#[test]
fn test_normalize_ipv6() {
let ip = IpAddr::V6(Ipv6Addr::from(0x20010DB8000000000000000000000042));
let expected = IpAddr::V6(Ipv6Addr::from(0x20010DB8000000000000000000000000));
assert_eq!(normalize_ip(&ip), expected);

let ip = IpAddr::V6(Ipv6Addr::from(0xFE8000000000000000000000000001));
let expected = IpAddr::V6(Ipv6Addr::from(0xFE8000000000000000000000000000));
assert_eq!(normalize_ip(&ip), expected);
}
}
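
The limiter wired into CancellationHandler above is the crate's LeakyBucketRateLimiter, keyed by the masked subnet from normalize_ip. The snippet below is a rough, self-contained illustration of the leaky-bucket idea only; it is not the crate's implementation, and the capacity/refill numbers are made up for the example.

use std::time::{Duration, Instant};

/// Toy leaky bucket: `level` drains at `rate` tokens per second; `capacity`
/// caps how large a burst can be admitted before requests are rejected.
struct LeakyBucket {
    capacity: f64,
    rate: f64,
    level: f64,
    last: Instant,
}

impl LeakyBucket {
    fn new(capacity: f64, rate: f64) -> Self {
        Self { capacity, rate, level: 0.0, last: Instant::now() }
    }

    /// Returns true if the request is admitted.
    fn check(&mut self, cost: f64) -> bool {
        let now = Instant::now();
        let drained = now.duration_since(self.last).as_secs_f64() * self.rate;
        self.level = (self.level - drained).max(0.0);
        self.last = now;
        if self.level + cost > self.capacity {
            return false;
        }
        self.level += cost;
        true
    }
}

fn main() {
    // Hypothetical numbers: burst of 10 cancellations per subnet, refill 1/sec.
    let mut bucket = LeakyBucket::new(10.0, 1.0);
    let mut admitted = 0;
    for _ in 0..12 {
        if bucket.check(1.0) {
            admitted += 1;
        }
    }
    assert_eq!(admitted, 10); // requests beyond the burst capacity are rejected
    std::thread::sleep(Duration::from_secs(1));
    assert!(bucket.check(1.0)); // the bucket has drained enough to admit again
}

In the commit itself the limiter is shared across connections behind a std::sync::Mutex and sharded (new_with_shards(..., 64)), as seen in the constructor changes above.
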
2 changes: 1 addition & 1 deletion proxy/src/compute.rs
@@ -341,7 +341,7 @@ impl ConnCfg {

// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);

let connection = PostgresConnection {
stream,
13 changes: 10 additions & 3 deletions proxy/src/console_redirect_proxy.rs
@@ -156,16 +156,21 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let request_gauge = metrics.connection_requests.guard(proto);

let tls = config.tls_config.as_ref();

let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);

let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
.cancel_session(
cancel_key_data,
ctx.session_id(),
&ctx.peer_addr(),
config.authentication_config.ip_allowlist_check_enabled,
)
.await
.map(|()| None)?)
}
Expand All @@ -174,7 +179,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

ctx.set_db_options(params.clone());

let user_info = match backend
let (user_info, ip_allowlist) = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Expand All @@ -198,6 +203,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.or_else(|e| stream.throw_error(e))
.await?;

node.cancel_closure
.set_ip_allowlist(ip_allowlist.unwrap_or_default());
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;

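
One reason the allowlist is copied onto the CancelClosure (via set_ip_allowlist above) is that a PostgreSQL CancelRequest arrives on a fresh, unauthenticated connection, so the only policy available at cancel time is whatever was captured when the original session was set up. A toy, self-contained sketch of that capture-then-check pattern follows; it is illustrative only and does not reproduce check_peer_addr_is_in_list. Treating an empty allowlist as "unrestricted" is an assumption of this sketch, not something stated in the diff.

use std::net::IpAddr;

// Policy is recorded when the session is established and consulted later,
// when a credential-less cancel request arrives for the same key.
struct StoredCancel {
    allowlist: Vec<IpAddr>,
}

impl StoredCancel {
    fn may_cancel_from(&self, peer: IpAddr) -> bool {
        // Assumption for this sketch: empty list means no restriction.
        self.allowlist.is_empty() || self.allowlist.contains(&peer)
    }
}

fn main() {
    let stored = StoredCancel { allowlist: vec!["10.0.0.7".parse().unwrap()] };
    assert!(stored.may_cancel_from("10.0.0.7".parse().unwrap()));
    assert!(!stored.may_cancel_from("203.0.113.5".parse().unwrap()));
}
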
1 change: 1 addition & 0 deletions proxy/src/metrics.rs
@@ -351,6 +351,7 @@ pub enum CancellationSource {
pub enum CancellationOutcome {
NotFound,
Found,
RateLimitExceeded,
}

#[derive(LabelGroup)]