From ceff99e05fe22d3a3eeda2a47ae06223c96f9ab2 Mon Sep 17 00:00:00 2001 From: Phoebe Goldman Date: Tue, 25 Jul 2023 11:19:02 -0400 Subject: [PATCH] Reduce duplication when looking up databases and auth Per Mazdak's comments. Also run rustfmt on this file. --- crates/client-api/src/routes/database.rs | 237 ++++++++++++++++------- 1 file changed, 163 insertions(+), 74 deletions(-) diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 25b145cf89..6f842a3110 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -21,15 +21,15 @@ use spacetimedb_lib::name::PublishOp; use spacetimedb_lib::sats::TypeInSpace; use crate::auth::{ - SpacetimeAuth, SpacetimeAuthHeader, SpacetimeEnergyUsed, SpacetimeExecutionDurationMicros, SpacetimeIdentity, - SpacetimeIdentityToken, + SpacetimeAuth, SpacetimeAuthHeader, SpacetimeEnergyUsed, SpacetimeExecutionDurationMicros, + SpacetimeIdentity, SpacetimeIdentityToken, }; use spacetimedb::address::Address; use spacetimedb::database_logger::DatabaseLogger; use spacetimedb::host::DescribedEntityType; use spacetimedb::identity::Identity; use spacetimedb::json::client_api::StmtResultJson; -use spacetimedb::messages::control_db::{DatabaseInstance, HostType}; +use spacetimedb::messages::control_db::{Database, DatabaseInstance, HostType}; use crate::util::{ByteStringBody, NameOrAddress}; use crate::{log_and_500, ControlCtx, ControlNodeDelegate, WorkerCtx}; @@ -69,10 +69,8 @@ pub async fn call( let args = ReducerArgs::Json(body); let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or_else(|| { log::error!("Could not find database: {}", address.to_hex()); (StatusCode::NOT_FOUND, "No such database.") @@ -98,7 +96,10 @@ pub async fn call( host.spawn_module_host(dbic).await.map_err(log_and_500)? } }; - let result = match module.call_reducer(caller_identity, None, &reducer, args).await { + let result = match module + .call_reducer(caller_identity, None, &reducer, args) + .await + { Ok(rcr) => rcr, Err(e) => { let status_code = match e { @@ -129,7 +130,11 @@ pub async fn call( )) } -fn reducer_outcome_response(identity: &Identity, reducer: &str, outcome: ReducerOutcome) -> (StatusCode, String) { +fn reducer_outcome_response( + identity: &Identity, + reducer: &str, + outcome: ReducerOutcome, +) -> (StatusCode, String) { match outcome { ReducerOutcome::Committed => (StatusCode::OK, "".to_owned()), ReducerOutcome::Failed(errmsg) => { @@ -189,19 +194,20 @@ async fn extract_db_call_info( ) -> Result { let auth = auth.get_or_create(ctx).await?; - let database = ctx - .get_database_by_address(address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(ctx, address) + .await? .ok_or_else(|| { log::error!("Could not find database: {}", address.to_hex()); (StatusCode::NOT_FOUND, "No such database.") })?; - let database_instance = ctx.get_leader_database_instance_by_database(database.id).await.ok_or(( - StatusCode::NOT_FOUND, - "Database instance not scheduled to this node yet.", - ))?; + let database_instance = ctx + .get_leader_database_instance_by_database(database.id) + .await + .ok_or(( + StatusCode::NOT_FOUND, + "Database instance not scheduled to this node yet.", + ))?; Ok(DatabaseInformation { database_instance, @@ -212,7 +218,12 @@ async fn extract_db_call_info( fn entity_description_json(description: TypeInSpace, expand: bool) -> Option { let typ = DescribedEntityType::from_entitydef(description.ty()).as_str(); let len = match description.ty() { - EntityDef::Table(t) => description.resolve(t.data).ty().as_product()?.elements.len(), + EntityDef::Table(t) => description + .resolve(t.data) + .ty() + .as_product()? + .elements + .len(), EntityDef::Reducer(r) => r.args.len(), }; if expand { @@ -262,10 +273,8 @@ pub async fn describe( auth: SpacetimeAuthHeader, ) -> axum::response::Result { let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or((StatusCode::NOT_FOUND, "No such database."))?; let call_info = extract_db_call_info(&*worker_ctx, auth, &address).await?; @@ -294,7 +303,12 @@ pub async fn describe( let description = catalog .get(&entity) .filter(|desc| DescribedEntityType::from_entitydef(desc.ty()) == entity_type) - .ok_or_else(|| (StatusCode::NOT_FOUND, format!("{entity_type} {entity:?} not found")))?; + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + format!("{entity_type} {entity:?} not found"), + ) + })?; let expand = expand.unwrap_or(true); let response_json = json!({ entity: entity_description_json(description, expand) }); @@ -318,10 +332,8 @@ pub async fn catalog( auth: SpacetimeAuthHeader, ) -> axum::response::Result { let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or((StatusCode::NOT_FOUND, "No such database."))?; let call_info = extract_db_call_info(&*worker_ctx, auth, &address).await?; @@ -366,10 +378,8 @@ pub async fn info( Path(InfoParams { name_or_address }): Path, ) -> axum::response::Result { let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or((StatusCode::NOT_FOUND, "No such database."))?; let host_type = match database.host_type { @@ -397,6 +407,11 @@ pub struct LogsQuery { follow: bool, } +fn auth_or_unauth(auth: SpacetimeAuthHeader) -> axum::response::Result { + auth.get() + .ok_or((StatusCode::UNAUTHORIZED, "Invalid credentials").into()) +} + pub async fn logs( State(worker_ctx): State>, Path(LogsParams { name_or_address }): Path, @@ -404,14 +419,16 @@ pub async fn logs( auth: SpacetimeAuthHeader, ) -> axum::response::Result { // You should not be able to read the logs from a database that you do not own - // so, unless you are the owner, this will fail, hence using get() and not get_or_create - let auth = auth.get().ok_or((StatusCode::UNAUTHORIZED, "Invalid credentials."))?; + // so, unless you are the owner, this will fail. + // TODO: This returns `UNAUTHORIZED` on failure, + // while everywhere else we return `BAD_REQUEST`. + // Is this special in some way? Should this change? + // Should all the others change? + let auth = auth_or_unauth(auth)?; let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or((StatusCode::NOT_FOUND, "No such database."))?; if database.identity != auth.identity { @@ -456,7 +473,11 @@ pub async fn logs( std::future::ready(match x { Ok(log) => Some(log), Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(skipped)) => { - log::trace!("Skipped {} lines in log for module {}", skipped, address.to_hex()); + log::trace!( + "Skipped {} lines in log for module {}", + skipped, + address.to_hex() + ); None } }) @@ -483,6 +504,16 @@ fn mime_ndjson() -> mime::Mime { "application/x-ndjson".parse().unwrap() } +async fn worker_ctx_find_database( + worker_ctx: &dyn WorkerCtx, + address: &Address, +) -> Result, StatusCode> { + worker_ctx + .get_database_by_address(address) + .await + .map_err(log_and_500) +} + #[derive(Deserialize)] pub struct SqlParams { name_or_address: NameOrAddress, @@ -503,10 +534,8 @@ pub async fn sql( let auth = auth.get_or_create(&*worker_ctx).await?; let address = name_or_address.resolve(&*worker_ctx).await?; - let database = worker_ctx - .get_database_by_address(&address) - .await - .map_err(log_and_500)? + let database = worker_ctx_find_database(&*worker_ctx, &address) + .await? .ok_or((StatusCode::NOT_FOUND, "No such database."))?; let auth = AuthCtx::new(database.identity, auth.identity); @@ -555,7 +584,11 @@ pub async fn sql( .into_iter() .map(|result| StmtResultJson { schema: result.head.ty(), - rows: result.data.into_iter().map(|x| x.elements).collect::>(), + rows: result + .data + .into_iter() + .map(|x| x.elements) + .collect::>(), }) .collect::>(); @@ -581,7 +614,11 @@ pub async fn dns( Query(DNSQueryParams {}): Query, ) -> axum::response::Result { let domain = database_name.parse().map_err(DomainParsingRejection)?; - let address = ctx.control_db().spacetime_dns(&domain).await.map_err(log_and_500)?; + let address = ctx + .control_db() + .spacetime_dns(&domain) + .await + .map_err(log_and_500)?; let response = if let Some(address) = address { DnsLookupResponse::Success { domain, @@ -613,6 +650,11 @@ pub struct RegisterTldParams { tld: String, } +fn auth_or_bad_request(auth: SpacetimeAuthHeader) -> axum::response::Result { + auth.get() + .ok_or((StatusCode::BAD_REQUEST, "Invalid credentials.").into()) +} + pub async fn register_tld( State(ctx): State>, Query(RegisterTldParams { tld }): Query, @@ -620,9 +662,12 @@ pub async fn register_tld( ) -> axum::response::Result { // You should not be able to publish to a database that you do not own // so, unless you are the owner, this will fail, hence not using get_or_create - let auth = auth.get().ok_or((StatusCode::BAD_REQUEST, "Invalid credentials."))?; + let auth = auth_or_bad_request(auth)?; - let tld = tld.parse::().map_err(DomainParsingRejection)?.into_tld(); + let tld = tld + .parse::() + .map_err(DomainParsingRejection)? + .into_tld(); let result = ctx .control_db() .spacetime_register_tld(tld, auth.identity) @@ -642,7 +687,11 @@ pub struct RequestRecoveryCodeParams { pub async fn request_recovery_code( State(ctx): State>, - Query(RequestRecoveryCodeParams { link, email, identity }): Query, + Query(RequestRecoveryCodeParams { + link, + email, + identity, + }): Query, ) -> axum::response::Result { let Some(sendgrid) = ctx.sendgrid_controller() else { log::error!("A recovery code was requested, but SendGrid is disabled."); @@ -695,7 +744,11 @@ pub struct ConfirmRecoveryCodeParams { /// for an identity that they don't have authority over. pub async fn confirm_recovery_code( State(ctx): State>, - Query(ConfirmRecoveryCodeParams { email, identity, code }): Query, + Query(ConfirmRecoveryCodeParams { + email, + identity, + code, + }): Query, ) -> axum::response::Result { let recovery_code = ctx .control_db() @@ -726,7 +779,11 @@ pub async fn confirm_recovery_code( .any(|a| a.identity == identity) { // This can happen if someone changes their associated email during a recovery request. - return Err((StatusCode::BAD_REQUEST, "No identity associated with that email.").into()); + return Err(( + StatusCode::BAD_REQUEST, + "No identity associated with that email.", + ) + .into()); } // Recovery code is verified, return the identity and token to the user @@ -739,6 +796,16 @@ pub async fn confirm_recovery_code( Ok(axum::Json(result)) } +async fn control_ctx_find_database( + ctx: &dyn ControlCtx, + address: &Address, +) -> Result, StatusCode> { + ctx.control_db() + .get_database_by_address(address) + .await + .map_err(log_and_500) +} + #[derive(Deserialize)] pub struct PublishDatabaseParams {} @@ -779,8 +846,8 @@ pub async fn publish( } = query_params; // You should not be able to publish to a database that you do not own - // so, unless you are the owner, this will fail, hence not using get_or_create - let auth = auth.get().ok_or((StatusCode::BAD_REQUEST, "Invalid credentials."))?; + // so, unless you are the owner, this will fail. + let auth = auth_or_bad_request(auth)?; let specified_address = matches!(name_or_address, Some(NameOrAddress::Address(_))); @@ -792,7 +859,11 @@ pub async fn publish( let domain = name.parse().map_err(DomainParsingRejection)?; // Client specified a name which doesn't yet exist // Create a new DNS record and a new address to assign to it - let address = ctx.control_db().alloc_spacetime_address().await.map_err(log_and_500)?; + let address = ctx + .control_db() + .alloc_spacetime_address() + .await + .map_err(log_and_500)?; let result = ctx .control_db() .spacetime_insert_domain(&address, domain, auth.identity, register_tld) @@ -813,7 +884,10 @@ pub async fn publish( } } else { // No domain or address was specified, create a new one - ctx.control_db().alloc_spacetime_address().await.map_err(log_and_500)? + ctx.control_db() + .alloc_spacetime_address() + .await + .map_err(log_and_500)? }; log::trace!("Publishing to the address: {}", db_address.to_hex()); @@ -831,15 +905,14 @@ pub async fn publish( let trace_log = should_trace(trace_log); - let op = match ctx - .control_db() - .get_database_by_address(&db_address) - .await - .map_err(log_and_500)? - { + let op = match control_ctx_find_database(&*ctx, &db_address).await? { Some(db) => { if Identity::from_slice(db.identity.as_slice()) != auth.identity { - return Err((StatusCode::BAD_REQUEST, "Identity does not own this database.").into()); + return Err(( + StatusCode::BAD_REQUEST, + "Identity does not own this database.", + ) + .into()); } if clear { @@ -864,7 +937,11 @@ pub async fn publish( let success = match res { Ok(success) => success, Err(e) => { - return Err((StatusCode::BAD_REQUEST, format!("Database update rejected: {e}")).into()); + return Err(( + StatusCode::BAD_REQUEST, + format!("Database update rejected: {e}"), + ) + .into()); } }; if let UpdateDatabaseSuccess { @@ -872,7 +949,11 @@ pub async fn publish( migrate_results: _, } = success { - match reducer_outcome_response(&auth.identity, "update", update_result.outcome) { + match reducer_outcome_response( + &auth.identity, + "update", + update_result.outcome, + ) { (StatusCode::OK, _) => {} (status, body) => return Err((status, body).into()), } @@ -886,7 +967,10 @@ pub async fn publish( None if specified_address => { return Err(( StatusCode::NOT_FOUND, - format!("Failed to find database at address: {}", db_address.to_hex()), + format!( + "Failed to find database at address: {}", + db_address.to_hex() + ), ) .into()) } @@ -931,17 +1015,16 @@ pub async fn delete_database( Path(DeleteDatabaseParams { address }): Path, auth: SpacetimeAuthHeader, ) -> axum::response::Result { - let auth = auth.get().ok_or((StatusCode::BAD_REQUEST, "Invalid credentials."))?; + let auth = auth_or_bad_request(auth)?; - match ctx - .control_db() - .get_database_by_address(&address) - .await - .map_err(log_and_500)? - { + match control_ctx_find_database(&*ctx, &address).await? { Some(db) => { if db.identity != auth.identity { - Err((StatusCode::BAD_REQUEST, "Identity does not own this database.").into()) + Err(( + StatusCode::BAD_REQUEST, + "Identity does not own this database.", + ) + .into()) } else { ctx.delete_database(&address) .await @@ -970,7 +1053,7 @@ pub async fn set_name( }): Query, auth: SpacetimeAuthHeader, ) -> axum::response::Result { - let auth = auth.get().ok_or((StatusCode::BAD_REQUEST, "Invalid credentials."))?; + let auth = auth_or_bad_request(auth)?; let database = ctx .control_db() @@ -1027,9 +1110,15 @@ where { use axum::routing::{get, post}; axum::Router::new() - .route("/subscribe/:name_or_address", get(super::subscribe::handle_websocket)) + .route( + "/subscribe/:name_or_address", + get(super::subscribe::handle_websocket), + ) .route("/call/:name_or_address/:reducer", post(call)) - .route("/schema/:name_or_address/:entity_type/:entity", get(describe)) + .route( + "/schema/:name_or_address/:entity_type/:entity", + get(describe), + ) .route("/schema/:name_or_address", get(catalog)) .route("/info/:name_or_address", get(info)) .route("/logs/:name_or_address", get(logs))