Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(repo): Added api key rate limiting on active ws sessions #398

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ STREAM_THROTTLE_LIVE=100
# Authentication & Security
KEYPAIR=generated-p2p-secret
API_PASSWORD=generated-password
API_RATE_LIMIT_DURATION_MILLIS=3000
API_KEY_MAX_CONN_LIMIT=10
GENERATED_KEYS=10

# Database Configuration
Expand Down
5 changes: 5 additions & 0 deletions crates/web-utils/src/server/middlewares/api_key/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub enum ApiKeyError {
Manager(#[from] ApiKeyManagerError),
#[error(transparent)]
InvalidHeader(#[from] InvalidHeaderValue),
#[error("API key rate limit exceeded: {0}")]
RateLimitExceeded(String),
}

impl From<ApiKeyError> for actix_web::Error {
Expand All @@ -44,6 +46,9 @@ impl From<ApiKeyError> for actix_web::Error {
ApiKeyError::InvalidHeader(e) => {
actix_web::error::ErrorUnauthorized(e.to_string())
}
ApiKeyError::RateLimitExceeded(info) => {
actix_web::error::ErrorTooManyRequests(info)
}
}
}
}
30 changes: 28 additions & 2 deletions crates/web-utils/src/server/middlewares/api_key/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use actix_web::http::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use fuel_streams_store::db::Db;
use serde::{Deserialize, Serialize};

use super::{ApiKeyError, ApiKeyStorageError};
use super::{
rate_limiter::RateLimitsController,
ApiKeyError,
ApiKeyStorageError,
};
use crate::server::middlewares::api_key::{
ApiKey,
InMemoryApiKeyStorage,
Expand Down Expand Up @@ -32,14 +36,20 @@ const BEARER: &str = "Bearer";
pub struct ApiKeysManager {
pub db: Arc<Db>,
pub storage: Arc<InMemoryApiKeyStorage>,
pub rate_limiter_controller: Option<Arc<RateLimitsController>>,
}

impl ApiKeysManager {
pub fn new(db: &Arc<Db>) -> Self {
pub fn new(db: &Arc<Db>, max_requests_per_key: Option<u64>) -> Self {
let storage = Arc::new(InMemoryApiKeyStorage::new());
Self {
db: db.to_owned(),
storage,
rate_limiter_controller: max_requests_per_key.map(
|max_requests_per_key| {
RateLimitsController::new(max_requests_per_key).arc()
},
),
}
}

Expand Down Expand Up @@ -104,6 +114,22 @@ impl ApiKeysManager {
}
}

pub fn check_rate_limit(
&self,
api_key: &ApiKey,
) -> Result<(), ApiKeyError> {
if let Some(rate_limiter_controller) =
self.rate_limiter_controller.as_ref()
{
let (is_ok, max_requests_per_key) = rate_limiter_controller
.check_rate_limit(api_key.user_id().into());
if !is_ok {
return Err(ApiKeyError::RateLimitExceeded(format!("Exceeded limit of {max_requests_per_key} active sessions per api key")));
}
}
Ok(())
}

pub fn key_from_headers(
&self,
(headers, query_map): (HeaderMap, HashMap<String, String>),
Expand Down
25 changes: 3 additions & 22 deletions crates/web-utils/src/server/middlewares/api_key/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::{
collections::HashMap,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use actix_service::Transform;
Expand All @@ -14,22 +13,16 @@ use actix_web::{
};
use futures_util::future::{ready, LocalBoxFuture, Ready};

use super::{rate_limiter::RateLimitsController, ApiKeyError, ApiKeysManager};
use super::{ApiKeyError, ApiKeysManager};

pub struct ApiKeyAuth {
manager: Arc<ApiKeysManager>,
rate_limiter_controller: Option<Arc<RateLimitsController>>,
}

impl ApiKeyAuth {
pub fn new(
manager: &Arc<ApiKeysManager>,
rate_limit_duration: Option<Duration>,
) -> Self {
pub fn new(manager: &Arc<ApiKeysManager>) -> Self {
ApiKeyAuth {
manager: manager.to_owned(),
rate_limiter_controller: rate_limit_duration
.map(|duration| RateLimitsController::new(duration).arc()),
}
}
}
Expand All @@ -53,15 +46,13 @@ where
ready(Ok(ApiKeyAuthMiddleware {
service: Arc::new(service),
manager: self.manager.clone(),
rate_limiter_controller: self.rate_limiter_controller.clone(),
}))
}
}

pub struct ApiKeyAuthMiddleware<S> {
service: Arc<S>,
manager: Arc<ApiKeysManager>,
rate_limiter_controller: Option<Arc<RateLimitsController>>,
}

impl<S, B> actix_service::Service<ServiceRequest> for ApiKeyAuthMiddleware<S>
Expand Down Expand Up @@ -105,7 +96,6 @@ where

let headers = req.headers().clone();
let manager = self.manager.clone();
let rate_limiter_controller = self.rate_limiter_controller.clone();
let service = self.service.clone();

Box::pin(async move {
Expand All @@ -115,16 +105,7 @@ where
.await?
.ok_or_else(|| Error::from(ApiKeyError::Invalid))?;

if let Some(rate_limiter_controller) = rate_limiter_controller {
if !rate_limiter_controller
.check_rate_limit(api_key.user_id().into())
.await
{
return Err(actix_web::error::ErrorTooManyRequests(
"Rate limit per user exceeded",
))
}
}
manager.check_rate_limit(&api_key)?;

match api_key.validate_status() {
Ok(()) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/web-utils/src/server/middlewares/api_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod api_key_impl;
mod errors;
mod manager;
pub mod middleware;
mod rate_limiter;
pub mod rate_limiter;
mod storage;
mod user_id;

Expand Down
58 changes: 36 additions & 22 deletions crates/web-utils/src/server/middlewares/api_key/rate_limiter.rs
Original file line number Diff line number Diff line change
@@ -1,60 +1,74 @@
use std::{
sync::Arc,
time::{Duration, Instant},
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};

use dashmap::DashMap;
use tokio::sync::RwLock;

#[derive(Clone, Debug)]
struct RateLimiter {
last_request_time: Arc<RwLock<Instant>>,
rate_limit: Duration,
requests_per_key: Arc<AtomicU64>,
max_requests_per_key: u64,
}

impl RateLimiter {
fn new(rate_limit: Duration) -> Self {
fn new(max_requests_per_key: u64) -> Self {
RateLimiter {
last_request_time: Arc::new(RwLock::new(Instant::now())),
rate_limit,
requests_per_key: Arc::new(AtomicU64::new(0)),
max_requests_per_key,
}
}
}

#[derive(Debug, Default)]
pub struct RateLimitsController {
map: DashMap<u64, RateLimiter>,
rate_limit: Duration,
max_requests_per_key: u64,
}

impl RateLimitsController {
pub fn new(rate_limit: Duration) -> Self {
pub fn new(max_requests_per_key: u64) -> Self {
Self {
map: DashMap::new(),
rate_limit,
max_requests_per_key,
}
}

pub fn arc(self) -> Arc<Self> {
Arc::new(self)
}

pub async fn check_rate_limit(&self, user_id: u64) -> bool {
pub fn add_active_key_sub(&self, user_id: u64) {
if let Some(user_rate_limiter) = self.map.get_mut(&user_id) {
user_rate_limiter
.requests_per_key
.fetch_add(1, Ordering::Relaxed);
}
}

pub fn remove_active_key_sub(&self, user_id: u64) {
if let Some(user_rate_limiter) = self.map.get_mut(&user_id) {
user_rate_limiter
.requests_per_key
.fetch_sub(1, Ordering::Relaxed);
}
}

pub fn check_rate_limit(&self, user_id: u64) -> (bool, u64) {
let user_rate_limiter = self.map.get(&user_id);
match user_rate_limiter {
Some(rate_limiter) => {
let last_request_time =
rate_limiter.last_request_time.read().await;
let elapsed = last_request_time.elapsed();
if elapsed < rate_limiter.rate_limit {
return false;
let requests_per_key =
rate_limiter.requests_per_key.load(Ordering::Relaxed);
if requests_per_key >= rate_limiter.max_requests_per_key {
return (false, rate_limiter.max_requests_per_key);
}
true
(true, self.max_requests_per_key)
}
None => {
let rate_limiter = RateLimiter::new(self.rate_limit);
let rate_limiter = RateLimiter::new(self.max_requests_per_key);
self.map.insert(user_id, rate_limiter);
true
(true, self.max_requests_per_key)
}
}
}
Expand All @@ -64,7 +78,7 @@ impl Clone for RateLimitsController {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
rate_limit: self.rate_limit,
max_requests_per_key: self.max_requests_per_key,
}
}
}
8 changes: 4 additions & 4 deletions services/webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod config;
pub mod metrics;
pub mod server;

use std::{sync::LazyLock, time::Duration};
use std::sync::LazyLock;

pub static STREAMER_MAX_WORKERS: LazyLock<usize> = LazyLock::new(|| {
let available_cpus = num_cpus::get();
Expand All @@ -18,9 +18,9 @@ pub static STREAMER_MAX_WORKERS: LazyLock<usize> = LazyLock::new(|| {
pub static API_PASSWORD: LazyLock<String> =
LazyLock::new(|| dotenvy::var("API_PASSWORD").ok().unwrap_or_default());

pub static API_RATE_LIMIT_DURATION_MILLIS: LazyLock<Option<Duration>> =
pub static API_KEY_MAX_CONN_LIMIT: LazyLock<Option<u64>> =
LazyLock::new(|| {
dotenvy::var("API_RATE_LIMIT_DURATION_MILLIS")
dotenvy::var("API_KEY_MAX_CONN_LIMIT")
.ok()
.and_then(|val| val.parse::<u64>().ok().map(Duration::from_millis))
.and_then(|val| val.parse::<u64>().ok())
});
7 changes: 2 additions & 5 deletions services/webserver/src/server/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use fuel_web_utils::server::{
};

use super::handlers;
use crate::{server::state::ServerState, API_RATE_LIMIT_DURATION_MILLIS};
use crate::server::state::ServerState;

pub fn create_services(
state: ServerState,
Expand All @@ -20,10 +20,7 @@ pub fn create_services(
cfg.app_data(web::Data::new(state.clone()));
cfg.service(
web::resource(with_prefixed_route("ws"))
.wrap(ApiKeyAuth::new(
&state.api_keys_manager,
*API_RATE_LIMIT_DURATION_MILLIS,
))
.wrap(ApiKeyAuth::new(&state.api_keys_manager))
.route(web::get().to({
move |req, body, state: web::Data<ServerState>| {
handlers::websocket::get_websocket(req, body, state)
Expand Down
16 changes: 14 additions & 2 deletions services/webserver/src/server/handlers/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use actix_web::{
use actix_ws::{CloseCode, CloseReason, Message, MessageStream, Session};
use fuel_streams_core::{server::ClientMessage, FuelStreams};
use fuel_web_utils::{
server::middlewares::api_key::ApiKey,
server::middlewares::api_key::{
rate_limiter::RateLimitsController,
ApiKey,
},
telemetry::Telemetry,
};
use futures::{
Expand Down Expand Up @@ -54,12 +57,15 @@ pub async fn get_websocket(
let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
let fuel_streams = state.fuel_streams.clone();
let telemetry = state.telemetry.clone();
let rate_limiter_controller =
state.api_keys_manager.rate_limiter_controller.clone();
actix_web::rt::spawn(handler(
session,
msg_stream,
telemetry,
fuel_streams,
api_key,
rate_limiter_controller,
));
Ok(response)
}
Expand All @@ -70,8 +76,14 @@ async fn handler(
telemetry: Arc<Telemetry<Metrics>>,
fuel_streams: Arc<FuelStreams>,
api_key: ApiKey,
rate_limiter_controller: Option<Arc<RateLimitsController>>,
) -> Result<(), WebsocketError> {
let mut ctx = WsSession::new(&api_key, telemetry, fuel_streams);
let mut ctx = WsSession::new(
&api_key,
telemetry,
fuel_streams,
rate_limiter_controller,
);
tracing::info!(
%api_key,
event = "websocket_connection_opened",
Expand Down
10 changes: 8 additions & 2 deletions services/webserver/src/server/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ use fuel_web_utils::{
telemetry::Telemetry,
};

use crate::{config::Config, metrics::Metrics, API_PASSWORD};
use crate::{
config::Config,
metrics::Metrics,
API_KEY_MAX_CONN_LIMIT,
API_PASSWORD,
};

#[derive(Clone)]
pub struct ServerState {
Expand Down Expand Up @@ -47,7 +52,8 @@ impl ServerState {
let telemetry = Telemetry::new(Some(metrics)).await?;
telemetry.start().await?;

let api_keys_manager = Arc::new(ApiKeysManager::new(&db));
let api_keys_manager =
Arc::new(ApiKeysManager::new(&db, *API_KEY_MAX_CONN_LIMIT));
let initial_keys = api_keys_manager.load_from_db().await?;
for key in initial_keys {
if let Err(e) = api_keys_manager.storage.insert(&key) {
Expand Down
Loading
Loading