From a67c91919ceac866d9018fe7b1647262fc9b5688 Mon Sep 17 00:00:00 2001 From: pmendes Date: Wed, 7 Aug 2024 18:51:59 +0100 Subject: [PATCH] axum rewrite --- Cargo.lock | 232 ++++++++- Cargo.toml | 16 +- crates/daphne-server/Cargo.toml | 10 +- crates/daphne-worker-test/src/lib.rs | 12 +- crates/daphne-worker/Cargo.toml | 8 + crates/daphne-worker/src/durable/mod.rs | 2 +- .../src/durable/test_state_cleaner.rs | 21 +- crates/daphne-worker/src/lib.rs | 4 + .../src/storage_proxy/middleware.rs | 149 ++++++ crates/daphne-worker/src/storage_proxy/mod.rs | 449 +++++++++--------- 10 files changed, 625 insertions(+), 278 deletions(-) create mode 100644 crates/daphne-worker/src/storage_proxy/middleware.rs diff --git a/Cargo.lock b/Cargo.lock index b7471d14a..1f9f2563a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -217,7 +217,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.3.4", "bitflags 1.3.2", "bytes", "futures-util", @@ -242,6 +242,33 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +dependencies = [ + "async-trait", + "axum-core 0.4.3", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 1.0.1", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -259,6 +286,62 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-core" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 0.1.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733" +dependencies = [ + "axum 0.7.5", + "axum-core 0.4.3", + "bytes", + "futures-util", + "headers", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.68", +] + [[package]] name = "az" version = "1.2.1" @@ -522,7 +605,7 @@ version = "4.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.68", @@ -575,6 +658,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constcat" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f2e5af989b1955b092db01462980c0a286217f86817e12b2c09aea46bd03651" + [[package]] name = "core-foundation" version = "0.9.4" @@ -794,7 +883,7 @@ version = "0.3.0" dependencies = [ "anyhow", "assert_matches", - "axum", + "axum 0.6.20", "clap", "config", "daphne", @@ -805,8 +894,8 @@ dependencies = [ "http 0.2.12", "hyper 0.14.29", "mappable-rc", - "opentelemetry", - "opentelemetry-http", + "opentelemetry 0.23.0", + "opentelemetry-http 0.12.0", "p256", "paste", "prio", @@ -821,7 +910,7 @@ dependencies = [ "tokio", "tower", "tracing", - "tracing-opentelemetry", + "tracing-opentelemetry 0.24.0", "tracing-subscriber", "url", "webpki", @@ -855,15 +944,22 @@ dependencies = [ name = "daphne-worker" version = "0.3.0" dependencies = [ + "axum 0.7.5", + "axum-extra", + "axum-macros", "bincode", + "bytes", "chrono", + "constcat", "daphne", "daphne-service-utils", "futures", "getrandom", "hex", - "opentelemetry", - "opentelemetry-http", + "http 1.1.0", + "http-body-util", + "opentelemetry 0.24.0", + "opentelemetry-http 0.13.0", "paste", "prio", "prometheus", @@ -874,9 +970,10 @@ dependencies = [ "serde", "serde-wasm-bindgen 0.6.5", "serde_json", + "tower-service", "tracing", "tracing-core", - "tracing-opentelemetry", + "tracing-opentelemetry 0.25.0", "tracing-subscriber", "url", "worker", @@ -1340,6 +1437,36 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 1.1.0", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.1.0", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -2033,6 +2160,20 @@ dependencies = [ "thiserror", ] +[[package]] +name = "opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-http" version = "0.12.0" @@ -2042,7 +2183,19 @@ dependencies = [ "async-trait", "bytes", "http 0.2.12", - "opentelemetry", + "opentelemetry 0.23.0", +] + +[[package]] +name = "opentelemetry-http" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab" +dependencies = [ + "async-trait", + "bytes", + "http 1.1.0", + "opentelemetry 0.24.0", ] [[package]] @@ -2058,18 +2211,36 @@ dependencies = [ "glob", "lazy_static", "once_cell", - "opentelemetry", + "opentelemetry 0.23.0", "ordered-float", "percent-encoding", "rand", "thiserror", ] +[[package]] +name = "opentelemetry_sdk" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.24.0", + "percent-encoding", + "rand", + "thiserror", +] + [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6" dependencies = [ "num-traits", ] @@ -3168,6 +3339,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -3270,7 +3452,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -3637,8 +3819,26 @@ checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" dependencies = [ "js-sys", "once_cell", - "opentelemetry", - "opentelemetry_sdk", + "opentelemetry 0.23.0", + "opentelemetry_sdk 0.23.0", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.24.0", + "opentelemetry_sdk 0.24.1", "smallvec", "tracing", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index 1de7bcb28..3080e9404 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,10 +32,12 @@ debug = 1 anyhow = "1.0.86" assert_matches = "1.5.0" async-trait = "0.1.80" -axum = "0.6" +axum = { version = "0.7.5", default-features = false } +axum-extra = "0.9" base64 = "0.21.7" # TODO(mendess): delete. Left here for legacy_req_parse. bincode = "1.3.3" +bytes = "1" cap = "0.1.2" capnp = "0.18.13" cfg-if = "1.0.0" @@ -50,13 +52,14 @@ hex = { version = "0.4.3", features = ["serde"] } hpke-rs = "0.2.0" hpke-rs-crypto = "0.2.0" hpke-rs-rust-crypto = "0.2.0" -http = "0.2" +http = "1" +http-body-util = "0.1" hyper = "0.14.29" itertools = "0.12.1" mappable-rc = "0.1.1" matchit = "0.7.3" -opentelemetry = "0.23.0" -opentelemetry-http = "0.12.0" +opentelemetry = "0.24.0" +opentelemetry-http = "0.13.0" p256 = { version = "0.13.2", features = ["ecdsa-core", "ecdsa", "pem"] } paste = "1.0.15" pin-project = "1.1.5" @@ -76,13 +79,14 @@ strum = { version = "0.26.3", features = ["derive"] } thiserror = "1.0.61" tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } tower = "0.4.13" +tower-service = "0.3" tracing = "0.1.40" tracing-core = "0.1.32" -tracing-opentelemetry = "0.24.0" +tracing-opentelemetry = "0.25.0" tracing-subscriber = "0.3.18" url = { version = "2.5.2", features = ["serde"] } webpki = "0.22.4" -worker = "0.3.3" +worker = { version = "0.3.3", features = ["http"] } x509-parser = "0.15.1" [workspace.dependencies.sentry] diff --git a/crates/daphne-server/Cargo.toml b/crates/daphne-server/Cargo.toml index e34817bfd..b3b41f996 100644 --- a/crates/daphne-server/Cargo.toml +++ b/crates/daphne-server/Cargo.toml @@ -13,16 +13,16 @@ repository.workspace = true description = "Workers backend for Daphne" [dependencies] -axum.workspace = true +axum = "0.6.0" daphne = { path = "../daphne" } daphne-service-utils = { path = "../daphne-service-utils", features = ["durable_requests"] } futures.workspace = true hex.workspace = true -http.workspace = true +http = "0.2" hyper.workspace = true mappable-rc.workspace = true -opentelemetry.workspace = true -opentelemetry-http.workspace = true +opentelemetry = "0.23.0" +opentelemetry-http = "0.12.0" p256.workspace = true prio.workspace = true rayon.workspace = true @@ -33,7 +33,7 @@ thiserror.workspace = true tokio.workspace = true tower.workspace = true tracing.workspace = true -tracing-opentelemetry.workspace = true +tracing-opentelemetry = "0.24.0" url.workspace = true [dev-dependencies] diff --git a/crates/daphne-worker-test/src/lib.rs b/crates/daphne-worker-test/src/lib.rs index f58f71714..71412647c 100644 --- a/crates/daphne-worker-test/src/lib.rs +++ b/crates/daphne-worker-test/src/lib.rs @@ -3,7 +3,7 @@ use daphne_worker::initialize_tracing; use tracing::info; -use worker::{event, Env, Request, Response, Result}; +use worker::{event, Env, HttpRequest}; mod durable; mod utils; @@ -12,7 +12,11 @@ mod utils; static CAP: cap::Cap = cap::Cap::new(std::alloc::System, 65_000_000); #[event(fetch, respond_with_errors)] -pub async fn main(req: Request, env: Env, _ctx: worker::Context) -> Result { +pub async fn main( + req: HttpRequest, + env: Env, + _ctx: worker::Context, +) -> worker::Result { // Optionally, get more helpful error messages written to the console in the case of a panic. utils::set_panic_hook(); @@ -20,7 +24,7 @@ pub async fn main(req: Request, env: Env, _ctx: worker::Context) -> Result tracing::Span { let extractor = crate::tracing_utils::HeaderExtractor::new(req); let remote_context = opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor)); - let span = info_span!("DO span", path = req.path()); + let span = info_span!("durable object", path = req.path()); span.set_parent(remote_context); span } diff --git a/crates/daphne-worker/src/durable/test_state_cleaner.rs b/crates/daphne-worker/src/durable/test_state_cleaner.rs index af1ae060a..77fb1fa14 100644 --- a/crates/daphne-worker/src/durable/test_state_cleaner.rs +++ b/crates/daphne-worker/src/durable/test_state_cleaner.rs @@ -3,15 +3,15 @@ use std::{cmp::min, ops::ControlFlow}; -use crate::{durable::create_span_from_request, initialize_tracing, int_err}; +use crate::{durable::create_span_from_request, int_err}; use daphne::messages::TaskId; use daphne_service_utils::durable_requests::bindings::{self, DurableMethod}; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; -use tracing::{error, trace, Instrument}; +use tracing::Instrument; use worker::{ - async_trait, durable_object, wasm_bindgen, wasm_bindgen_futures, Date, Env, ListOptions, - Method, Request, Response, Result, State, Stub, + async_trait, console_debug, console_error, durable_object, wasm_bindgen, wasm_bindgen_futures, + Date, Env, ListOptions, Method, Request, Response, Result, State, Stub, }; use super::GcDurableObject; @@ -27,7 +27,6 @@ pub struct TestStateCleaner { #[durable_object] impl DurableObject for TestStateCleaner { fn new(state: State, env: Env) -> Self { - initialize_tracing(&env); Self { state, env } } @@ -49,17 +48,17 @@ impl TestStateCleaner { bindings::AggregateStore::BINDING => (), s => { let message = format!("GarbageCollector: unrecognized binding: {s}"); - error!("{}", message); + console_error!("{}", message); return Err(int_err(message)); } }; let queued = DurableOrdered::new_roughly_ordered(durable_ref, "object"); queued.put(&self.state).await?; - trace!( + console_debug!( + "registered DO instance for deletion. binding: {binding}, instance: {instance}", binding = queued.as_ref().binding, instance = queued.as_ref().id_hex, - "registered DO instance for deletion", ); Response::from_json(&()) } @@ -87,10 +86,10 @@ impl TestStateCleaner { &(), ) .await?; - trace!( + console_debug!( + "deleted instance. binding: {binding}. instance: {instance}", binding = durable_ref.binding, instance = durable_ref.id_hex, - "deleted instance", ); } @@ -104,7 +103,7 @@ impl TestStateCleaner { req.method(), req.path() ); - error!("{}", message); + console_error!("{}", message); Err(int_err(message)) } } diff --git a/crates/daphne-worker/src/lib.rs b/crates/daphne-worker/src/lib.rs index c0f6bc3af..28e2a9682 100644 --- a/crates/daphne-worker/src/lib.rs +++ b/crates/daphne-worker/src/lib.rs @@ -13,6 +13,10 @@ use tracing::error; use worker::Error; pub use crate::tracing_utils::initialize_tracing; +pub use axum::{ + body::Body, + response::{IntoResponse, Response}, +}; pub use daphne::DapRequest; pub(crate) fn int_err(s: S) -> Error { diff --git a/crates/daphne-worker/src/storage_proxy/middleware.rs b/crates/daphne-worker/src/storage_proxy/middleware.rs new file mode 100644 index 000000000..1f65a0b7e --- /dev/null +++ b/crates/daphne-worker/src/storage_proxy/middleware.rs @@ -0,0 +1,149 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use std::{ + sync::{Arc, OnceLock}, + time::Duration, +}; + +use axum::{ + extract::{Path, State}, + middleware::Next, + response::IntoResponse, +}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + TypedHeader, +}; +use daphne::messages::constant_time_eq; +use http::{Method, StatusCode}; +use tower_service::Service; +use worker::send::SendFuture; + +use super::RequestContext; + +/// Check if the request's authorization. If unauthorized, return the reason why. +pub async fn unauthorized_reason( + ctx: State>, + bearer: TypedHeader>, + request: axum::extract::Request, + mut next: Next, +) -> axum::response::Response { + static TRUSTED_TOKEN: OnceLock> = OnceLock::new(); + + let Some(trusted_token) = TRUSTED_TOKEN.get_or_init(|| { + ctx.env + .var("DAPHNE_SERVER_AUTH_TOKEN") + .ok() + .map(|t| t.to_string()) + }) else { + tracing::warn!("trusted bearer token not configured"); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Authorization token for storage proxy is not configured", + ) + .into_response(); + }; + + if !constant_time_eq(bearer.token().as_bytes(), trusted_token.as_bytes()) { + return (StatusCode::UNAUTHORIZED, "Incorrect authorization token").into_response(); + } + + next.call(request.map(axum::body::Body::new)).await.unwrap() +} + +pub async fn time_kv_requests( + ctx: State>, + method: Method, + request: axum::extract::Request, + next: Next, +) -> axum::response::Response { + SendFuture::new(time_kv_requests_impl(ctx, method, request, next)).await +} + +async fn time_kv_requests_impl( + ctx: State>, + method: Method, + request: axum::extract::Request, + mut next: Next, +) -> axum::response::Response { + let start = worker::Date::now(); + let response = match next.call(request).await { + Ok(response) => response, + Err(infallible) => match infallible {}, + }; + let elapsed = elapsed(&start); + match (method, response.status()) { + (Method::GET, status) if status.is_success() => ctx + .metrics + .kv_request_time_seconds_observe("kv_get", "success", elapsed), + (Method::GET, _) => ctx + .metrics + .kv_request_time_seconds_observe("kv_get", "error", elapsed), + (Method::POST, status) if status.is_success() => ctx + .metrics + .kv_request_time_seconds_observe("kv_put", "success", elapsed), + (Method::POST, _) => ctx + .metrics + .kv_request_time_seconds_observe("kv_put", "error", elapsed), + (Method::PUT, status) if status.is_success() => ctx + .metrics + .kv_request_time_seconds_observe("kv_put_if_not_exists", "success", elapsed), + (Method::PUT, _) => { + ctx.metrics + .kv_request_time_seconds_observe("kv_put_if_not_exists", "error", elapsed) + } + (Method::DELETE, status) if status.is_success() => ctx + .metrics + .kv_request_time_seconds_observe("kv_delete", "success", elapsed), + (Method::DELETE, _) => { + ctx.metrics + .kv_request_time_seconds_observe("kv_delete", "error", elapsed) + } + (method, status) => { + tracing::warn!( + ?method, + ?status, + "unexpected method in kv request. Not measuring time" + ) + } + } + response +} + +pub async fn time_do_requests( + ctx: State>, + uri: Path, + request: axum::extract::Request, + next: Next, +) -> axum::response::Response { + SendFuture::new(time_do_requests_impl(ctx, uri, request, next)).await +} + +async fn time_do_requests_impl( + ctx: State>, + Path(uri): Path, + request: axum::extract::Request, + mut next: Next, +) -> axum::response::Response { + let start = worker::Date::now(); + let response = match next.call(request).await { + Ok(response) => response, + Err(infallible) => match infallible {}, + }; + let elapsed = elapsed(&start); + ctx.metrics.durable_request_time_seconds_observe( + &uri, + if response.status().is_success() { + "success" + } else { + "error" + }, + elapsed, + ); + response +} + +fn elapsed(date: &worker::Date) -> Duration { + Duration::from_millis(worker::Date::now().as_millis() - date.as_millis()) +} diff --git a/crates/daphne-worker/src/storage_proxy/mod.rs b/crates/daphne-worker/src/storage_proxy/mod.rs index 36509898e..e5090d7b4 100644 --- a/crates/daphne-worker/src/storage_proxy/mod.rs +++ b/crates/daphne-worker/src/storage_proxy/mod.rs @@ -70,115 +70,121 @@ //! [to_uri]: daphne_service_utils::durable_requests::bindings::DurableMethod::to_uri mod metrics; +mod middleware; -use std::{sync::OnceLock, time::Duration}; +use std::{sync::Arc, time::Duration}; pub use self::metrics::Metrics; -use daphne::auth::BearerToken; +use axum::{ + extract::{Path, State}, + middleware::from_fn_with_state, + response::{IntoResponse, Response}, + routing, +}; +use bytes::Bytes; use daphne::messages::Time; use daphne_service_utils::durable_requests::{ DurableRequest, ObjectIdFrom, DO_PATH_PREFIX, KV_PATH_PREFIX, }; use daphne_service_utils::http_headers::STORAGE_PROXY_PUT_KV_EXPIRATION; +use http::{HeaderMap, StatusCode}; +use opentelemetry_http::HeaderExtractor; use prometheus::Registry; +use tower_service::Service; use tracing::{info_span, warn}; use tracing_opentelemetry::OpenTelemetrySpanExt; use url::Url; -use worker::{js_sys::Uint8Array, Delay, Env, Request, RequestInit, Response}; +use worker::HttpRequest; +use worker::{ + js_sys::Uint8Array, send::SendFuture, Delay, Env, HttpResponse, Request, RequestInit, +}; const KV_BINDING_DAP_CONFIG: &str = "DAP_CONFIG"; -struct RequestContext<'e> { - req: Request, - env: &'e Env, +struct RequestContext { + env: Env, metrics: Metrics, } -/// Check if the request's authorization. If unauthorized, return the reason why. -fn unauthorized_reason(ctx: &RequestContext) -> Option> { - static TRUSTED_TOKEN: OnceLock> = OnceLock::new(); +struct Error(worker::Error); - let access_denied = |reason| Response::error(format!("Unauthorized: {reason}"), 401); - let auth = match ctx.req.headers().get("Authorization") { - Ok(Some(auth)) => auth, - Ok(None) => return Some(access_denied("missing Authorization header")), - Err(e) => return Some(Err(e)), - }; - let Some(provided_token) = auth.strip_prefix("Bearer ").map(BearerToken::from) else { - return Some(access_denied("Authorization header has unexpected prefix")); - }; - let Some(trusted_token) = TRUSTED_TOKEN.get_or_init(|| { - ctx.env - .var("DAPHNE_SERVER_AUTH_TOKEN") - .ok() - .map(|t| t.to_string()) - .map(BearerToken::from) - }) else { - tracing::warn!("trusted bearer token not configured"); - return Some(Response::error( - "Authorization token for storage proxy is not configured", - 500, - )); - }; - if &provided_token != trusted_token { - return Some(access_denied("Incorrect authorization token")); +impl From for Error { + fn from(value: worker::kv::KvError) -> Self { + Self(worker::Error::from(value)) + } +} + +impl From for Error { + fn from(value: worker::Error) -> Self { + Self(value) } +} - None +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } } /// Handle a proxy request. This is the entry point of the Worker. -#[allow(clippy::no_effect_underscore_binding)] -pub async fn handle_request( - req: Request, - env: &Env, - registry: &Registry, -) -> worker::Result { - let span = info_span!("handle_request", path = req.path(), method = ?req.method()); +pub async fn handle_request(req: HttpRequest, env: Env, registry: &Registry) -> Response { + let span = info_span!("handle_request", path = req.uri().path(), method = ?req.method()); { - let extractor = crate::tracing_utils::HeaderExtractor::new(&req); + let extractor = HeaderExtractor(req.headers()); let remote_context = opentelemetry::global::get_text_map_propagator(|propagator| { propagator.extract(&extractor) }); span.set_parent(remote_context); } - let mut ctx = RequestContext { + let ctx = Arc::new(RequestContext { metrics: Metrics::new(registry), - req, env, - }; - - if let Some(error_response) = unauthorized_reason(&ctx) { - return error_response; - } - - let path = ctx.req.path(); - if let Some(uri) = path - .strip_prefix(KV_PATH_PREFIX) - .and_then(|s| s.strip_prefix('/')) - { - handle_kv_request(&mut ctx, uri).await - } else if let Some(uri) = path.strip_prefix(DO_PATH_PREFIX) { - handle_do_request(&mut ctx, uri).await - } else { - #[cfg(feature = "test-utils")] - if let Some("") = path.strip_prefix(daphne_service_utils::durable_requests::PURGE_STORAGE) { - return storage_purge(&ctx).await; - } else if let Some("") = - path.strip_prefix(daphne_service_utils::durable_requests::STORAGE_READY) - { - return Response::ok(""); - } - - tracing::error!("path {path:?} was invalid"); - Response::error("invalid base path", 400) - } + }); + + let router = axum::Router::new() + .route( + constcat::concat!(KV_PATH_PREFIX, "/*path"), + routing::get(s2(kv_get)) + .post(s4(kv_put)) + .put(s4(kv_put_if_not_exists)) + .delete(s2(kv_delete)) + .route_layer(from_fn_with_state( + ctx.clone(), + middleware::time_kv_requests, + )), + ) + .route( + constcat::concat!(DO_PATH_PREFIX, "/*path"), + routing::any(s4(handle_do_request)).layer(from_fn_with_state( + ctx.clone(), + middleware::time_do_requests, + )), + ); + + #[cfg(feature = "test-utils")] + let router = router + .route( + daphne_service_utils::durable_requests::PURGE_STORAGE, + routing::any(|ctx| SendFuture::new(storage_purge(ctx))), + ) + .route( + daphne_service_utils::durable_requests::STORAGE_READY, + routing::any(StatusCode::OK), + ); + + let mut router = router + .layer(from_fn_with_state( + ctx.clone(), + middleware::unauthorized_reason, + )) + .with_state(ctx); + router.call(req).await.unwrap() } #[cfg(feature = "test-utils")] -/// Clear all storage. Only available to tests -async fn storage_purge(ctx: &RequestContext<'_>) -> worker::Result { +#[tracing::instrument(skip(ctx))] +async fn storage_purge(ctx: State>) -> impl IntoResponse + 'static { use daphne_service_utils::durable_requests::bindings::{DurableMethod, TestStateCleaner}; let kv_delete = async { @@ -192,7 +198,7 @@ async fn storage_purge(ctx: &RequestContext<'_>) -> worker::Result { let do_delete = async { let mut req = Request::new_with_init( - &format!("https://fake-host{}", TestStateCleaner::DeleteAll.to_uri(),), + &format!("https://fake-host{}", TestStateCleaner::DeleteAll.to_uri()), RequestInit::new().with_method(worker::Method::Post), )?; @@ -206,12 +212,17 @@ async fn storage_purge(ctx: &RequestContext<'_>) -> worker::Result { .await }; - futures::try_join!(kv_delete, do_delete)?; - Response::empty() + futures::try_join!(kv_delete, do_delete) + .map(|_| ()) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()) } -fn parse_expiration_header(ctx: &RequestContext) -> Result, worker::Error> { - let expiration_header = ctx.req.headers().get(STORAGE_PROXY_PUT_KV_EXPIRATION)?; +fn parse_expiration_header(headers: &HeaderMap) -> Result, worker::Error> { + let expiration_header = headers + .get(STORAGE_PROXY_PUT_KV_EXPIRATION) + .map(|h| h.to_str()) + .transpose() + .map_err(|e| worker::Error::RustError(e.to_string()))?; expiration_header .map(|expiration| { expiration.parse::