Skip to content

Commit

Permalink
proxy: Implement cancellation rate limiting
Browse files Browse the repository at this point in the history
Implement cancellation rate limiting and ip allowlist checks.
Add ip_allowlist to the cancel closure

Fixes #16456
  • Loading branch information
awarus committed Nov 15, 2024
1 parent ceaa80f commit 5579b38
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 44 deletions.
20 changes: 12 additions & 8 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tokio_postgres::config::SslMode;
use tracing::{info, info_span};

use super::ComputeCredentialKeys;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
Expand Down Expand Up @@ -74,10 +75,10 @@ impl ConsoleRedirectBackend {
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<ConsoleRedirectNodeInfo> {
) -> auth::Result<(ConsoleRedirectNodeInfo, Option<Vec<IpPattern>>)> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(ConsoleRedirectNodeInfo)
.map(|(node_info, ip_allowlist)| (ConsoleRedirectNodeInfo(node_info), ip_allowlist))
}
}

Expand All @@ -102,7 +103,7 @@ async fn authenticate(
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
) -> auth::Result<(NodeInfo, Option<Vec<IpPattern>>)> {
ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect);

// registering waiter can fail if we get unlucky with rng.
Expand Down Expand Up @@ -176,9 +177,12 @@ async fn authenticate(
config.password(password.as_ref());
}

Ok(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
})
Ok((
NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
},
db_info.allowed_ips,
))
}
18 changes: 1 addition & 17 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod local;

use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;

pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
Expand All @@ -30,7 +29,7 @@ use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
Expand Down Expand Up @@ -192,21 +191,6 @@ impl MaskedIp {
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;

impl RateBucketInfo {
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
}

impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
Expand Down
13 changes: 12 additions & 1 deletion proxy/src/bin/local_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use proxy::http::health_server::AppMetrics;
use proxy::intern::RoleNameInt;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, LeakyBucketConfig, RateBucketInfo,
};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
Expand Down Expand Up @@ -67,6 +67,9 @@ struct LocalProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
user_rps_limit: Vec<RateBucketInfo>,
/// Cancel rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
cancel_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
Expand Down Expand Up @@ -201,6 +204,13 @@ async fn main() -> anyhow::Result<()> {
},
));

let cancel_rps_limit = Vec::leak(args.cancel_rps_limit.clone());
RateBucketInfo::validate(cancel_rps_limit)?;

let cancel_rate_limiter = Arc::new(std::sync::Mutex::new(GlobalRateLimiter::new(
cancel_rps_limit.into(),
)));

let task = serverless::task_main(
config,
auth_backend,
Expand All @@ -210,6 +220,7 @@ async fn main() -> anyhow::Result<()> {
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
cancel_rate_limiter,
)),
endpoint_rate_limiter,
);
Expand Down
17 changes: 15 additions & 2 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use proxy::context::parquet::ParquetUploadArgs;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
use proxy::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
EndpointRateLimiter, GlobalRateLimiter, LeakyBucketConfig, RateBucketInfo,
WakeComputeRateLimiter,
};
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
Expand Down Expand Up @@ -163,6 +164,9 @@ struct ProxyCliArgs {
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Cancellation query rate limiter, max number of cancellation requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
cancel_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
Expand Down Expand Up @@ -427,12 +431,21 @@ async fn main() -> anyhow::Result<()> {
)?))),
None => None,
};

let cancel_rps_limit = Vec::leak(args.cancel_rps_limit.clone());
RateBucketInfo::validate(cancel_rps_limit)?;

let cancel_rate_limiter = Arc::new(std::sync::Mutex::new(GlobalRateLimiter::new(
cancel_rps_limit.into(),
)));

let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
Option<Arc<Mutex<RedisPublisherClient>>>,
>::new(
cancel_map.clone(),
redis_publisher,
proxy::metrics::CancellationSource::FromClient,
cancel_rate_limiter,
));

// bit of a hack - find the min rps and max rps supported and turn it into
Expand Down
82 changes: 76 additions & 6 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;

use crate::auth::{check_peer_addr_is_in_list, IpPattern};
use crate::error::ReportableError;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::rate_limiter::GlobalRateLimiter;
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
};
use std::net::IpAddr;

pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
Expand All @@ -29,14 +32,23 @@ pub struct CancellationHandler<P> {
/// This field used for the monitoring purposes.
/// Represents the source of the cancellation request.
from: CancellationSource,
// rate limiter of cancellation requests
limiter: Arc<std::sync::Mutex<GlobalRateLimiter>>,
}

#[derive(Debug, Error)]
pub(crate) enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),

#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),

#[error("rate limit exceeded")]
RateLimit,

#[error("IP is not allowed")]
Unauthorized,
}

impl ReportableError for CancelError {
Expand All @@ -47,6 +59,8 @@ impl ReportableError for CancelError {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::Unauthorized => crate::error::ErrorKind::User,
}
}
}
Expand Down Expand Up @@ -79,13 +93,28 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
cancellation_handler: self,
}
}

/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
pub(crate) async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
peer_addr: Option<&IpAddr>,
) -> Result<(), CancelError> {
// probably better would be to use lockless rate limiter
if !self.limiter.lock().unwrap().check() {
tracing::debug!("Rate limit exceeded. Skipping cancellation message");
Metrics::get()
.proxy
.cancellation_requests_total
.inc(CancellationRequest {
source: self.from,
kind: crate::metrics::CancellationOutcome::RateLimitExceeded,
});
return Err(CancelError::RateLimit);
}

// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
Expand All @@ -107,6 +136,15 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}
return Ok(());
};

if let Some(addr) = peer_addr {
if !addr.is_unspecified()
&& !check_peer_addr_is_in_list(addr, cancel_closure.ip_allowlist.as_slice())
{
return Err(CancelError::Unauthorized);
}
}

Metrics::get()
.proxy
.cancellation_requests_total
Expand All @@ -130,18 +168,33 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
}

impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: CancellationSource) -> Self {
pub fn new(
map: CancelMap,
from: CancellationSource,
limiter: Arc<std::sync::Mutex<GlobalRateLimiter>>,
) -> Self {
Self {
map,
client: (),
from,
limiter,
}
}
}

impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: CancellationSource) -> Self {
Self { map, client, from }
pub fn new(
map: CancelMap,
client: Option<Arc<Mutex<P>>>,
from: CancellationSource,
limiter: Arc<std::sync::Mutex<GlobalRateLimiter>>,
) -> Self {
Self {
map,
client,
from,
limiter,
}
}
}

Expand All @@ -152,13 +205,19 @@ impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
}

impl CancelClosure {
pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: CancelToken,
ip_allowlist: Vec<IpPattern>,
) -> Self {
Self {
socket_addr,
cancel_token,
ip_allowlist,
}
}
/// Cancels the query running on user's compute node.
Expand All @@ -168,6 +227,9 @@ impl CancelClosure {
info!("query was cancelled");
Ok(())
}
pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec<IpPattern>) {
self.ip_allowlist = ip_allowlist;
}
}

/// Helper for registering query cancellation tokens.
Expand Down Expand Up @@ -201,12 +263,15 @@ impl<P> Drop for Session<P> {
#[cfg(test)]
mod tests {
use super::*;
use crate::rate_limiter::RateBucketInfo;

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let info = RateBucketInfo::new(50, std::time::Duration::from_secs(5));
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
CancelMap::default(),
CancellationSource::FromRedis,
Arc::new(std::sync::Mutex::new(GlobalRateLimiter::new(vec![info]))),
));

let session = cancellation_handler.clone().get_session();
Expand All @@ -220,15 +285,20 @@ mod tests {

#[tokio::test]
async fn cancel_session_noop_regression() {
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
let info = RateBucketInfo::new(50, std::time::Duration::from_secs(5));
let handler = CancellationHandler::<()>::new(
CancelMap::default(),
CancellationSource::Local,
Arc::new(std::sync::Mutex::new(GlobalRateLimiter::new(vec![info]))),
);
handler
.cancel_session(
CancelKeyData {
backend_pid: 0,
cancel_key: 0,
},
Uuid::new_v4(),
Some(&("127.0.0.1".parse().unwrap())),
)
.await
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ impl ConnCfg {

// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);

let connection = PostgresConnection {
stream,
Expand Down
Loading

0 comments on commit 5579b38

Please sign in to comment.