Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: protect ourselves from going over the 1k limit #1444

2 changes: 2 additions & 0 deletions common/src/models/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub enum ErrorKind {
NotReady,
ServiceUnavailable,
DeleteProjectFailed,
ContainerLimit,
}

impl From<ErrorKind> for ApiError {
Expand Down Expand Up @@ -131,6 +132,7 @@ impl From<ErrorKind> for ApiError {
ErrorKind::Forbidden => (StatusCode::FORBIDDEN, "Forbidden"),
ErrorKind::NotReady => (StatusCode::INTERNAL_SERVER_ERROR, "Service not ready"),
ErrorKind::DeleteProjectFailed => (StatusCode::INTERNAL_SERVER_ERROR, "Deleting project failed"),
ErrorKind::ContainerLimit => (StatusCode::SERVICE_UNAVAILABLE, "Our server is full and cannot create / start projects at this time"),
};
Self {
message: error_message.to_string(),
Expand Down
5 changes: 5 additions & 0 deletions common/src/models/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ pub mod name {
&& name.bytes().all(is_valid_char)
&& is_profanity_free(name)
}

/// Whether this project name marks a cch project.
///
/// A name counts as a cch project when it begins with the `cch23-` prefix.
pub fn is_cch_project(&self) -> bool {
    const CCH_PREFIX: &str = "cch23-";

    self.starts_with(CCH_PREFIX)
}
}

impl std::ops::Deref for ProjectName {
Expand Down
95 changes: 91 additions & 4 deletions gateway/src/api/latest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,35 @@ async fn create_project(
CustomErrorPath(project_name): CustomErrorPath<ProjectName>,
AxumJson(config): AxumJson<project::Config>,
) -> Result<AxumJson<project::Response>, Error> {
let cch_modifier = project_name.starts_with("cch23-");
let is_cch_project = project_name.is_cch_project();

// Check that the user is within their project limits.
let can_create_project = claim.can_create_project(
service
.get_project_count(&name)
.await?
.saturating_sub(cch_modifier as u32),
.saturating_sub(is_cch_project as u32),
);

if is_cch_project {
let current_container_count = service.count_ready_projects().await?;
oddgrd marked this conversation as resolved.
Show resolved Hide resolved

if current_container_count >= service.container_limit() {
return Err(Error::from_kind(ErrorKind::ContainerLimit));
}
}
Kazy marked this conversation as resolved.
Show resolved Hide resolved

let project = service
.create_project(
project_name.clone(),
name.clone(),
claim.is_admin(),
can_create_project,
if cch_modifier { 5 } else { config.idle_minutes },
if is_cch_project {
5
} else {
config.idle_minutes
},
)
.await?;
let idle_minutes = project.state.idle_minutes();
Expand Down Expand Up @@ -434,6 +446,17 @@ async fn route_project(
req: Request<Body>,
) -> Result<Response<Body>, Error> {
let project_name = scoped_user.scope;
let is_cch_project = project_name.is_cch_project();

// Don't start cch projects if we will be going over the container limit
if is_cch_project {
let current_container_count = service.count_ready_projects().await?;

if current_container_count >= service.container_limit() {
return Err(Error::from_kind(ErrorKind::ContainerLimit));
}
}

let project = service.find_or_start_project(&project_name, sender).await?;
service
.route(&project.state, &project_name, &scoped_user.user.name, req)
Expand Down Expand Up @@ -1125,7 +1148,7 @@ pub mod tests {
use super::*;
use crate::project::ProjectError;
use crate::service::GatewayService;
use crate::tests::{RequestBuilderExt, TestProject, World};
use crate::tests::{RequestBuilderExt, TestGateway, TestProject, World};

#[tokio::test]
async fn api_create_get_delete_projects() -> anyhow::Result<()> {
Expand Down Expand Up @@ -1381,6 +1404,70 @@ pub mod tests {
Ok(())
}

#[test_context(TestGateway)]
#[tokio::test]
async fn api_create_project_above_container_limit(gateway: &mut TestGateway) {
    // Create one regular project first so a project is already ready before the cch request.
    let _ = gateway.create_project("matrix").await;

    // A cch project must be refused with the container-limit status...
    let cch_status = gateway.try_create_project("cch23-project").await;
    assert_eq!(cch_status, StatusCode::SERVICE_UNAVAILABLE);

    // ...while a project without the cch prefix is still accepted.
    let regular_status = gateway.try_create_project("project").await;
    assert_eq!(
        regular_status,
        StatusCode::OK,
        "it should be possible to still create normal projects"
    );
}

/// Idle cch projects must not be restarted while another project is ready, whereas idle
/// regular projects are still started on demand.
#[test_context(TestGateway)]
#[tokio::test]
async fn start_idle_project_when_above_container_limit(gateway: &mut TestGateway) {
    let mut cch_idle_project = gateway.create_project("cch23-project").await;

    // Run four health checks to get the project to go into idle mode
    // (cch projects always default to 5 min of idle time)
    for _ in 0..4 {
        cch_idle_project.run_health_check().await;
    }

    cch_idle_project
        .wait_for_state(project::State::Stopped)
        .await;

    let mut normal_idle_project = gateway.create_project("project").await;

    // Run four health checks here as well to get the project to go into idle mode
    for _ in 0..4 {
        normal_idle_project.run_health_check().await;
    }

    normal_idle_project
        .wait_for_state(project::State::Stopped)
        .await;

    // Bring up another project and keep it alive for the rest of the test, so there is a
    // ready project when the idle ones are woken up.
    let _project_two = gateway.create_project("matrix").await;

    // Now try to start the idle projects
    let cch_code = cch_idle_project
        .router_call(Method::GET, "/services/cch23-project")
        .await;

    assert_eq!(cch_code, StatusCode::SERVICE_UNAVAILABLE);

    let normal_code = normal_idle_project
        .router_call(Method::GET, "/services/project")
        .await;

    assert_eq!(
        normal_code,
        StatusCode::NOT_FOUND,
        "should not be able to find a service since nothing was deployed"
    );
}

#[test_context(TestProject)]
#[tokio::test]
async fn api_delete_project_that_is_ready(project: &mut TestProject) {
Expand Down
3 changes: 3 additions & 0 deletions gateway/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pub struct ContextArgs {
/// Api key for the user that has rights to start deploys
#[arg(long, default_value = "gateway4deployes")]
pub deploys_api_key: String,
/// Maximum number of containers to start on this node
#[arg(long, default_value = "900")]
pub container_limit: u32,

/// Allow tests to set some extra /etc/hosts
pub extra_hosts: Vec<String>,
Expand Down
108 changes: 73 additions & 35 deletions gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ pub mod tests {
network_name,
proxy_fqdn: FQDN::from_str("test.shuttleapp.rs").unwrap(),
deploys_api_key: "gateway".to_string(),
container_limit: 1,

// Allow access to the auth on the host
extra_hosts: vec!["host.docker.internal:host-gateway".to_string()],
Expand Down Expand Up @@ -768,6 +769,76 @@ pub mod tests {
}
}

/// Helper struct to wrap a bunch of commands to run against gateway's API
///
/// Projects created through this helper receive clones of the router, authorization,
/// service and sender held here.
pub struct TestGateway {
/// Router that the API requests are sent through
router: Router,
/// Bearer authorization attached to every request
authorization: Authorization<Bearer>,
/// Shared gateway service, cloned into each created `TestProject`
service: Arc<GatewayService>,
/// Task sender, cloned into each created `TestProject`
sender: Sender<BoxedTask>,
/// Test world this gateway was built from; supplies the pool for created projects
world: World,
}

impl TestGateway {
    /// Send a project-creation request for `project_name` and hand back the status code of
    /// the response, without asserting on it.
    pub async fn try_create_project(&mut self, project_name: &str) -> StatusCode {
        let uri = format!("/projects/{project_name}");
        let request = Request::builder()
            .method("POST")
            .uri(uri)
            .header("Content-Type", "application/json")
            .body("{\"idle_minutes\": 3}".into())
            .unwrap()
            .with_header(&self.authorization);

        let response = self.router.call(request).await.unwrap();

        response.status()
    }

    /// Create `project_name`, assert that the request succeeded, and return a `TestProject`
    /// wrapper once the project has reached the ready state.
    pub async fn create_project(&mut self, project_name: &str) -> TestProject {
        assert_eq!(
            self.try_create_project(project_name).await,
            StatusCode::OK
        );

        // The wrapper shares this gateway's router, service and task sender.
        let mut project = TestProject {
            authorization: self.authorization.clone(),
            project_name: project_name.to_string(),
            router: self.router.clone(),
            pool: self.world.pool(),
            service: self.service.clone(),
            sender: self.sender.clone(),
        };

        project.wait_for_state(project::State::Ready).await;

        project
    }
}

#[async_trait]
impl AsyncTestContext for TestGateway {
async fn setup() -> Self {
let world = World::new().await;

let (service, sender) = world.service().await;

let router = world.router(service.clone(), sender.clone());
let authorization = world.create_authorization_bearer("neo");

Self {
router,
authorization,
service,
sender,
world,
}
}

async fn teardown(mut self) {}
}

/// Helper struct to wrap a bunch of commands to run against a test project
pub struct TestProject {
router: Router,
Expand Down Expand Up @@ -1024,42 +1095,9 @@ pub mod tests {
#[async_trait]
impl AsyncTestContext for TestProject {
async fn setup() -> Self {
let world = World::new().await;

let (service, sender) = world.service().await;

let mut router = world.router(service.clone(), sender.clone());
let authorization = world.create_authorization_bearer("neo");
let project_name = "matrix";

router
.call(
Request::builder()
.method("POST")
.uri(format!("/projects/{project_name}"))
.header("Content-Type", "application/json")
.body("{\"idle_minutes\": 3}".into())
.unwrap()
.with_header(&authorization),
)
.map_ok(|resp| {
assert_eq!(resp.status(), StatusCode::OK);
})
.await
.unwrap();

let mut this = TestProject {
authorization,
project_name: project_name.to_string(),
router,
pool: world.pool(),
service,
sender,
};
let mut world = TestGateway::setup().await;

this.wait_for_state(project::State::Ready).await;

this
world.create_project("matrix").await
}

async fn teardown(mut self) {
Expand Down
10 changes: 8 additions & 2 deletions gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, error, info, info_span, trace, warn, Instrument};
use tracing::{debug, error, field, info, info_span, trace, warn, Instrument, Span};

#[tokio::main(flavor = "multi_thread")]
async fn main() -> io::Result<()> {
Expand Down Expand Up @@ -101,7 +101,8 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
if let Ok(projects) = gateway.iter_projects().await {
let span = info_span!(
"running health checks",
healthcheck.num_projects = projects.len()
healthcheck.num_projects = projects.len(),
healthcheck.active_projects = field::Empty,
);

let gateway = gateway.clone();
Expand All @@ -116,6 +117,11 @@ async fn start(db: SqlitePool, fs: PathBuf, args: StartArgs) -> io::Result<()> {
handle.await
}
}

let active_projects =
gateway.count_ready_projects().await.unwrap_or_default();
let span = Span::current();
span.record("healthcheck.active_projects", active_projects);
}
.instrument(span)
.await;
Expand Down
20 changes: 20 additions & 0 deletions gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ pub struct GatewayService {
task_router: TaskRouter<BoxedTask>,
state_location: PathBuf,

/// Maximum number of containers the gateway can start
container_limit: u32,

// We store these because we'll need them for the health checks
provisioner_host: Endpoint,
auth_host: Uri,
Expand Down Expand Up @@ -275,6 +278,7 @@ impl GatewayService {
provisioner_host: Endpoint::new(format!("http://{}:8000", args.provisioner_host))
.expect("to have a valid provisioner endpoint"),
auth_host: args.auth_uri,
container_limit: args.container_limit,
}
}

Expand Down Expand Up @@ -323,6 +327,17 @@ impl GatewayService {
Ok(iter)
}

/// The number of projects that are currently in the ready state
pub async fn count_ready_projects(&self) -> Result<u32, Error> {
oddgrd marked this conversation as resolved.
Show resolved Hide resolved
let ready_count: u32 =
query("SELECT COUNT(*) FROM projects, JSON_EACH(project_state) WHERE key = 'ready'")
.fetch_one(&self.db)
.await?
.get::<_, usize>(0);

Ok(ready_count)
}

pub async fn find_project(
&self,
project_name: &ProjectName,
Expand Down Expand Up @@ -924,6 +939,11 @@ impl GatewayService {
pub fn auth_uri(&self) -> &Uri {
&self.auth_host
}

/// Maximum number of containers that can be started by gateway
///
/// Returns the `container_limit` value stored on this service.
pub fn container_limit(&self) -> u32 {
self.container_limit
}
}

#[derive(Clone)]
Expand Down