From 0de170059202bd85c597a778ea4fe4ab5c9ca8bb Mon Sep 17 00:00:00 2001 From: Gary Pennington <gary@apollographql.com> Date: Thu, 6 Jun 2024 09:12:23 +0100 Subject: [PATCH 1/4] Add Extensions with_lock() to try and avoid timing issues We hit a lot of problems with timing that causes flaky tests because it's not always easy to know when the compiler drops the Extensions lock. We can side-step this a lot of the time by introducing (and using) a with_lock() function. --- apollo-router/src/context/extensions/sync.rs | 8 ++++++++ apollo-router/src/services/router/service.rs | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/apollo-router/src/context/extensions/sync.rs b/apollo-router/src/context/extensions/sync.rs index c5350f9791..fe63d6fa66 100644 --- a/apollo-router/src/context/extensions/sync.rs +++ b/apollo-router/src/context/extensions/sync.rs @@ -32,6 +32,14 @@ impl ExtensionsMutex { guard: self.extensions.lock(), } } + + /// Locks the extensions for interaction. + /// + /// The lock will be dropped once the closure completes. + pub fn with_lock<'a, T, F: FnOnce(ExtensionsGuard<'a>) -> T>(&'a self, func: F) -> T { + let locked = self.lock(); + func(locked) + } } pub struct ExtensionsGuard<'a> { diff --git a/apollo-router/src/services/router/service.rs b/apollo-router/src/services/router/service.rs index ba35fd6c0e..e08bbfc5bc 100644 --- a/apollo-router/src/services/router/service.rs +++ b/apollo-router/src/services/router/service.rs @@ -441,7 +441,9 @@ impl RouterService { // Regardless of the result, we need to make sure that we cancel any potential batch queries. This is because // custom rust plugins, rhai scripts, and coprocessors can cancel requests at any time and return a GraphQL // error wrapped in an `Ok` or in a `BoxError` wrapped in an `Err`. - let batch_query_opt = context.extensions().lock().remove::<BatchQuery>(); + let batch_query_opt = context + .extensions() + .with_lock(|mut lock| lock.remove::<BatchQuery>()); if let Some(batch_query) = batch_query_opt { // Only proceed with signalling cancelled if the batch_query is not finished if !batch_query.finished() { From 1f2fb324dcc129865b142df35c0b3d260d7cfbb1 Mon Sep 17 00:00:00 2001 From: Gary Pennington <gary@apollographql.com> Date: Thu, 6 Jun 2024 09:32:33 +0100 Subject: [PATCH 2/4] add a changeset --- .../feat_garypen_router_340_extensions_with_lock.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .changesets/feat_garypen_router_340_extensions_with_lock.md diff --git a/.changesets/feat_garypen_router_340_extensions_with_lock.md b/.changesets/feat_garypen_router_340_extensions_with_lock.md new file mode 100644 index 0000000000..3bfadfa59a --- /dev/null +++ b/.changesets/feat_garypen_router_340_extensions_with_lock.md @@ -0,0 +1,7 @@ +### Add Extensions with_lock() to try and avoid timing issues ([PR #5360](https://github.com/apollographql/router/pull/5360)) + +It's easy to trip over issues when interacting with Extensions because we inadvertently hold locks for too long. This can be a source of bugs in the router and causes a lot of tests to be flaky. + +with_lock() avoids this kind of problem by explicitly restricting the lifetime of the Extensions lock. + +By [@garypen](https://github.com/garypen) in https://github.com/apollographql/router/pull/5360 \ No newline at end of file From a0963c4883af0a76e837cabdf5e6651c1c175c2f Mon Sep 17 00:00:00 2001 From: Gary Pennington <gary@apollographql.com> Date: Thu, 6 Jun 2024 14:27:29 +0100 Subject: [PATCH 3/4] Replace all remaining usage of lock() with with_lock() Only review this commit if you are very interested in the details. --- .../axum_factory/axum_http_server_factory.rs | 4 +- apollo-router/src/context/extensions/sync.rs | 7 +- apollo-router/src/context/mod.rs | 15 ++- .../src/plugins/authentication/subgraph.rs | 14 ++- .../plugins/authorization/authenticated.rs | 12 +- .../src/plugins/authorization/mod.rs | 10 +- apollo-router/src/plugins/cache/entity.rs | 29 +++-- .../cost_calculator/static_cost.rs | 3 +- .../src/plugins/demand_control/mod.rs | 45 +++----- .../strategy/static_estimated.rs | 33 +++--- .../plugins/demand_control/strategy/test.rs | 76 +++++++------ apollo-router/src/plugins/file_uploads/mod.rs | 25 ++--- .../src/plugins/progressive_override/mod.rs | 2 +- .../src/plugins/progressive_override/tests.rs | 9 +- .../src/plugins/record_replay/record.rs | 76 +++++++------ .../plugins/telemetry/config_new/cost/mod.rs | 13 ++- .../plugins/telemetry/config_new/events.rs | 9 +- .../telemetry/config_new/graphql/mod.rs | 4 +- .../plugins/telemetry/config_new/selectors.rs | 19 ++-- apollo-router/src/plugins/telemetry/mod.rs | 104 ++++++++---------- .../plugins/traffic_shaping/deduplication.rs | 3 +- .../src/query_planner/bridge_query_planner.rs | 30 ++--- .../query_planner/caching_query_planner.rs | 81 +++++++------- apollo-router/src/query_planner/execution.rs | 4 +- apollo-router/src/services/http/service.rs | 4 +- .../layers/allow_only_http_post_mutations.rs | 13 +-- apollo-router/src/services/layers/apq.rs | 8 +- .../services/layers/content_negotiation.rs | 14 ++- .../services/layers/persisted_queries/mod.rs | 24 ++-- .../src/services/layers/query_analysis.rs | 13 +-- apollo-router/src/services/router/service.rs | 63 +++++------ .../src/services/subgraph_service.rs | 33 ++---- .../src/services/supergraph/service.rs | 40 +++---- .../src/services/supergraph/tests.rs | 16 ++- 34 files changed, 416 insertions(+), 439 deletions(-) diff --git a/apollo-router/src/axum_factory/axum_http_server_factory.rs b/apollo-router/src/axum_factory/axum_http_server_factory.rs index 39a267fe9e..e54433622c 100644 --- a/apollo-router/src/axum_factory/axum_http_server_factory.rs +++ b/apollo-router/src/axum_factory/axum_http_server_factory.rs @@ -754,7 +754,9 @@ impl<'a> Drop for CancelHandler<'a> { self.span .in_scope(|| tracing::error!("broken pipe: the client closed the connection")); } - self.context.extensions().lock().insert(CanceledRequest); + self.context + .extensions() + .with_lock(|mut lock| lock.insert(CanceledRequest)); } } } diff --git a/apollo-router/src/context/extensions/sync.rs b/apollo-router/src/context/extensions/sync.rs index fe63d6fa66..bc35c0b5a8 100644 --- a/apollo-router/src/context/extensions/sync.rs +++ b/apollo-router/src/context/extensions/sync.rs @@ -25,6 +25,7 @@ impl ExtensionsMutex { /// Doing so may cause performance degradation or even deadlocks. /// /// See related clippy lint for examples: <https://rust-lang.github.io/rust-clippy/master/index.html#/await_holding_lock> + #[deprecated] pub fn lock(&self) -> ExtensionsGuard { ExtensionsGuard { #[cfg(debug_assertions)] @@ -37,7 +38,11 @@ impl ExtensionsMutex { /// /// The lock will be dropped once the closure completes. pub fn with_lock<'a, T, F: FnOnce(ExtensionsGuard<'a>) -> T>(&'a self, func: F) -> T { - let locked = self.lock(); + let locked = ExtensionsGuard { + #[cfg(debug_assertions)] + start: Instant::now(), + guard: self.extensions.lock(), + }; func(locked) } } diff --git a/apollo-router/src/context/mod.rs b/apollo-router/src/context/mod.rs index ef71f953ea..fee4e29e72 100644 --- a/apollo-router/src/context/mod.rs +++ b/apollo-router/src/context/mod.rs @@ -271,9 +271,7 @@ impl Context { #[doc(hidden)] pub fn unsupported_executable_document(&self) -> Option<Arc<Valid<ExecutableDocument>>> { self.extensions() - .lock() - .get::<ParsedDocument>() - .map(|d| d.executable.clone()) + .with_lock(|lock| lock.get::<ParsedDocument>().map(|d| d.executable.clone())) } } @@ -421,10 +419,11 @@ mod test { fn context_extensions() { // This is mostly tested in the extensions module. let c = Context::new(); - let mut extensions = c.extensions().lock(); - extensions.insert(1usize); - let v = extensions.get::<usize>(); - assert_eq!(v, Some(&1usize)); + c.extensions().with_lock(|mut lock| lock.insert(1usize)); + let v = c + .extensions() + .with_lock(|lock| lock.get::<usize>().cloned()); + assert_eq!(v, Some(1usize)); } #[test] @@ -455,7 +454,7 @@ mod test { let document = Query::parse_document("{ me }", None, &schema, &Configuration::default()).unwrap(); assert!(c.unsupported_executable_document().is_none()); - c.extensions().lock().insert(document); + c.extensions().with_lock(|mut lock| lock.insert(document)); assert!(c.unsupported_executable_document().is_some()); } } diff --git a/apollo-router/src/plugins/authentication/subgraph.rs b/apollo-router/src/plugins/authentication/subgraph.rs index 40006c2fb0..5b5df352c2 100644 --- a/apollo-router/src/plugins/authentication/subgraph.rs +++ b/apollo-router/src/plugins/authentication/subgraph.rs @@ -374,7 +374,9 @@ impl SubgraphAuth { ServiceBuilder::new() .map_request(move |req: SubgraphRequest| { let signing_params = signing_params.clone(); - req.context.extensions().lock().insert(signing_params); + req.context + .extensions() + .with_lock(|mut lock| lock.insert(signing_params)); req }) .service(service) @@ -644,11 +646,11 @@ mod test { request: &SubgraphRequest, service_name: String, ) -> hyper::Request<hyper::Body> { - let signing_params = { - let ctx = request.context.extensions().lock(); - let sp = ctx.get::<Arc<SigningParamsConfig>>(); - sp.cloned().unwrap() - }; + let signing_params = request + .context + .extensions() + .with_lock(|lock| lock.get::<Arc<SigningParamsConfig>>().cloned()) + .unwrap(); let http_request = request .clone() diff --git a/apollo-router/src/plugins/authorization/authenticated.rs b/apollo-router/src/plugins/authorization/authenticated.rs index bca06237c6..f88796ce4d 100644 --- a/apollo-router/src/plugins/authorization/authenticated.rs +++ b/apollo-router/src/plugins/authorization/authenticated.rs @@ -1661,11 +1661,13 @@ mod tests { .unwrap();*/ let mut headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue> = MultiMap::new(); headers.insert("Accept".into(), "multipart/mixed;deferSpec=20220824".into()); - context.extensions().lock().insert(ClientRequestAccepts { - multipart_defer: true, - multipart_subscription: true, - json: true, - wildcard: true, + context.extensions().with_lock(|mut lock| { + lock.insert(ClientRequestAccepts { + multipart_defer: true, + multipart_subscription: true, + json: true, + wildcard: true, + }) }); let request = supergraph::Request::fake_builder() .query("query { orga(id: 1) { id creatorUser { id } ... @defer { nonNullId } } }") diff --git a/apollo-router/src/plugins/authorization/mod.rs b/apollo-router/src/plugins/authorization/mod.rs index f3cb7431f4..14b43ec039 100644 --- a/apollo-router/src/plugins/authorization/mod.rs +++ b/apollo-router/src/plugins/authorization/mod.rs @@ -291,10 +291,12 @@ impl AuthorizationPlugin { .unwrap_or_default(); policies.sort(); - context.extensions().lock().insert(CacheKeyMetadata { - is_authenticated, - scopes, - policies, + context.extensions().with_lock(|mut lock| { + lock.insert(CacheKeyMetadata { + is_authenticated, + scopes, + policies, + }) }); } diff --git a/apollo-router/src/plugins/cache/entity.rs b/apollo-router/src/plugins/cache/entity.rs index 4839b9197a..914ef61589 100644 --- a/apollo-router/src/plugins/cache/entity.rs +++ b/apollo-router/src/plugins/cache/entity.rs @@ -163,11 +163,11 @@ impl Plugin for EntityCache { fn supergraph_service(&self, service: supergraph::BoxService) -> supergraph::BoxService { ServiceBuilder::new() .map_response(|mut response: supergraph::Response| { - if let Some(cache_control) = { - let lock = response.context.extensions().lock(); - let cache_control = lock.get::<CacheControl>().cloned(); - cache_control - } { + if let Some(cache_control) = response + .context + .extensions() + .with_lock(|lock| lock.get::<CacheControl>().cloned()) + { let _ = cache_control.to_headers(response.response.headers_mut()); } @@ -437,7 +437,10 @@ async fn cache_lookup_root( match cache_result { Some(value) => { - request.context.extensions().lock().insert(value.0.control); + request + .context + .extensions() + .with_lock(|mut lock| lock.insert(value.0.control)); Ok(ControlFlow::Break( subgraph::Response::builder() @@ -519,12 +522,14 @@ async fn cache_lookup_entities( } fn update_cache_control(context: &Context, cache_control: &CacheControl) { - if let Some(c) = context.extensions().lock().get_mut::<CacheControl>() { - *c = c.merge(cache_control); - return; - } - //FIXME: race condition. We need an Entry API for private entries - context.extensions().lock().insert(cache_control.clone()); + context.extensions().with_lock(|mut lock| { + if let Some(c) = lock.get_mut::<CacheControl>() { + *c = c.merge(cache_control); + } else { + //FIXME: race condition. We need an Entry API for private entries + lock.insert(cache_control.clone()); + } + }) } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs index 3ea6ca4ba1..7bf870cbff 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs +++ b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs @@ -469,7 +469,8 @@ mod tests { .unwrap(); let ctx = Context::new(); - ctx.extensions().lock().insert::<ParsedDocument>(query); + ctx.extensions() + .with_lock(|mut lock| lock.insert::<ParsedDocument>(query)); let planner_res = planner .call(QueryPlannerRequest::new(query_str.to_string(), None, ctx)) diff --git a/apollo-router/src/plugins/demand_control/mod.rs b/apollo-router/src/plugins/demand_control/mod.rs index 2e59170b73..fc36115b2e 100644 --- a/apollo-router/src/plugins/demand_control/mod.rs +++ b/apollo-router/src/plugins/demand_control/mod.rs @@ -203,9 +203,9 @@ pub(crate) struct DemandControl { impl DemandControl { fn report_operation_metric(context: Context) { - let guard = context.extensions().lock(); - let cost_context = guard.get::<CostContext>(); - let result = cost_context.map_or("NO_CONTEXT", |c| c.result); + let result = context + .extensions() + .with_lock(|lock| lock.get::<CostContext>().map_or("NO_CONTEXT", |c| c.result)); u64_counter!( "apollo.router.operations.demand_control", "Total operations with demand control enabled", @@ -237,7 +237,9 @@ impl Plugin for DemandControl { let strategy = self.strategy_factory.create(); ServiceBuilder::new() .checkpoint(move |req: execution::Request| { - req.context.extensions().lock().insert(strategy.clone()); + req.context + .extensions() + .with_lock(|mut lock| lock.insert(strategy.clone())); // On the request path we need to check for estimates, checkpoint is used to do this, short-circuiting the request if it's too expensive. Ok(match strategy.on_execution_request(&req) { Ok(_) => ControlFlow::Continue(req), @@ -258,13 +260,9 @@ impl Plugin for DemandControl { .context .unsupported_executable_document() .expect("must have document"); - let strategy = resp - .context - .extensions() - .lock() - .get::<Strategy>() - .expect("must have strategy") - .clone(); + let strategy = resp.context.extensions().with_lock(|lock| { + lock.get::<Strategy>().expect("must have strategy").clone() + }); let context = resp.context.clone(); // We want to sequence this code to run after all the subgraph responses have been scored. @@ -323,13 +321,9 @@ impl Plugin for DemandControl { } else { ServiceBuilder::new() .checkpoint(move |req: subgraph::Request| { - let strategy = req - .context - .extensions() - .lock() - .get::<Strategy>() - .expect("must have strategy") - .clone(); + let strategy = req.context.extensions().with_lock(|lock| { + lock.get::<Strategy>().expect("must have strategy").clone() + }); // On the request path we need to check for estimates, checkpoint is used to do this, short-circuiting the request if it's too expensive. Ok(match strategy.on_subgraph_request(&req) { @@ -355,13 +349,9 @@ impl Plugin for DemandControl { }, |req: Arc<Valid<ExecutableDocument>>, fut| async move { let resp: subgraph::Response = fut.await?; - let strategy = resp - .context - .extensions() - .lock() - .get::<Strategy>() - .expect("must have strategy") - .clone(); + let strategy = resp.context.extensions().with_lock(|lock| { + lock.get::<Strategy>().expect("must have strategy").clone() + }); Ok(match strategy.on_subgraph_response(req.as_ref(), &resp) { Ok(_) => resp, Err(err) => subgraph::Response::builder() @@ -556,7 +546,7 @@ mod test { let strategy = plugin.strategy_factory.create(); let ctx = context(); - ctx.extensions().lock().insert(strategy); + ctx.extensions().with_lock(|mut lock| lock.insert(strategy)); let mut req = subgraph::Request::fake_builder() .subgraph_name("test") .context(ctx) @@ -582,8 +572,7 @@ mod test { }; let ctx = Context::new(); ctx.extensions() - .lock() - .insert(ParsedDocument::new(parsed_document)); + .with_lock(|mut lock| lock.insert(ParsedDocument::new(parsed_document))); ctx } diff --git a/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs b/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs index c7aacd8b8d..6e7447d6b4 100644 --- a/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs +++ b/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs @@ -20,19 +20,20 @@ impl StrategyImpl for StaticEstimated { self.cost_calculator .planned(&request.query_plan) .and_then(|cost| { - let mut extensions = request.context.extensions().lock(); - let cost_result = extensions.get_or_default_mut::<CostContext>(); - cost_result.estimated = cost; - if cost > self.max { - Err( - cost_result.result(DemandControlError::EstimatedCostTooExpensive { - estimated_cost: cost, - max_cost: self.max, - }), - ) - } else { - Ok(()) - } + request.context.extensions().with_lock(|mut lock| { + let cost_result = lock.get_or_default_mut::<CostContext>(); + cost_result.estimated = cost; + if cost > self.max { + Err( + cost_result.result(DemandControlError::EstimatedCostTooExpensive { + estimated_cost: cost, + max_cost: self.max, + }), + ) + } else { + Ok(()) + } + }) }) } @@ -56,9 +57,9 @@ impl StrategyImpl for StaticEstimated { ) -> Result<(), DemandControlError> { if response.data.is_some() { let cost = self.cost_calculator.actual(request, response)?; - let mut extensions = context.extensions().lock(); - let cost_result = extensions.get_or_default_mut::<CostContext>(); - cost_result.actual = cost; + context + .extensions() + .with_lock(|mut lock| lock.get_or_default_mut::<CostContext>().actual = cost); } Ok(()) } diff --git a/apollo-router/src/plugins/demand_control/strategy/test.rs b/apollo-router/src/plugins/demand_control/strategy/test.rs index f50f290e5a..347f77d79f 100644 --- a/apollo-router/src/plugins/demand_control/strategy/test.rs +++ b/apollo-router/src/plugins/demand_control/strategy/test.rs @@ -17,30 +17,32 @@ pub(crate) struct Test { impl StrategyImpl for Test { fn on_execution_request(&self, request: &Request) -> Result<(), DemandControlError> { - let mut extensions = request.context.extensions().lock(); - let cost_context = extensions.get_or_default_mut::<CostContext>(); - match self { - Test { - stage: TestStage::ExecutionRequest, - error, - } => Err(cost_context.result(error.into())), - _ => Ok(()), - } + request.context.extensions().with_lock(|mut lock| { + let cost_context = lock.get_or_default_mut::<CostContext>(); + match self { + Test { + stage: TestStage::ExecutionRequest, + error, + } => Err(cost_context.result(error.into())), + _ => Ok(()), + } + }) } fn on_subgraph_request( &self, request: &crate::services::subgraph::Request, ) -> Result<(), DemandControlError> { - let mut extensions = request.context.extensions().lock(); - let cost_context = extensions.get_or_default_mut::<CostContext>(); - match self { - Test { - stage: TestStage::SubgraphRequest, - error, - } => Err(cost_context.result(error.into())), - _ => Ok(()), - } + request.context.extensions().with_lock(|mut lock| { + let cost_context = lock.get_or_default_mut::<CostContext>(); + match self { + Test { + stage: TestStage::SubgraphRequest, + error, + } => Err(cost_context.result(error.into())), + _ => Ok(()), + } + }) } fn on_subgraph_response( @@ -48,15 +50,16 @@ impl StrategyImpl for Test { _request: &ExecutableDocument, response: &Response, ) -> Result<(), DemandControlError> { - let mut extensions = response.context.extensions().lock(); - let cost_context = extensions.get_or_default_mut::<CostContext>(); - match self { - Test { - stage: TestStage::SubgraphResponse, - error, - } => Err(cost_context.result(error.into())), - _ => Ok(()), - } + response.context.extensions().with_lock(|mut lock| { + let cost_context = lock.get_or_default_mut::<CostContext>(); + match self { + Test { + stage: TestStage::SubgraphResponse, + error, + } => Err(cost_context.result(error.into())), + _ => Ok(()), + } + }) } fn on_execution_response( @@ -65,14 +68,15 @@ impl StrategyImpl for Test { _request: &ExecutableDocument, _response: &crate::graphql::Response, ) -> Result<(), DemandControlError> { - let mut extensions = context.extensions().lock(); - let cost_context = extensions.get_or_default_mut::<CostContext>(); - match self { - Test { - stage: TestStage::ExecutionResponse, - error, - } => Err(cost_context.result(error.into())), - _ => Ok(()), - } + context.extensions().with_lock(|mut lock| { + let cost_context = lock.get_or_default_mut::<CostContext>(); + match self { + Test { + stage: TestStage::ExecutionResponse, + error, + } => Err(cost_context.result(error.into())), + _ => Ok(()), + } + }) } } diff --git a/apollo-router/src/plugins/file_uploads/mod.rs b/apollo-router/src/plugins/file_uploads/mod.rs index a605a76e8e..61a373c735 100644 --- a/apollo-router/src/plugins/file_uploads/mod.rs +++ b/apollo-router/src/plugins/file_uploads/mod.rs @@ -176,7 +176,9 @@ async fn router_layer( let mut multipart = MultipartRequest::new(request_body, boundary, limits); let operations_stream = multipart.operations_field().await?; - req.context.extensions().lock().insert(multipart); + req.context + .extensions() + .with_lock(|mut lock| lock.insert(multipart)); let content_type = operations_stream .headers() @@ -202,9 +204,7 @@ async fn supergraph_layer(mut req: supergraph::Request) -> Result<supergraph::Re let multipart = req .context .extensions() - .lock() - .get::<MultipartRequest>() - .cloned(); + .with_lock(|lock| lock.get::<MultipartRequest>().cloned()); if let Some(mut multipart) = multipart { let map_field = multipart.map_field().await?; @@ -226,13 +226,12 @@ async fn supergraph_layer(mut req: supergraph::Request) -> Result<supergraph::Re } } - req.context - .extensions() - .lock() - .insert(SupergraphLayerResult { + req.context.extensions().with_lock(|mut lock| { + lock.insert(SupergraphLayerResult { multipart, map: Arc::new(map_field), - }); + }) + }); } Ok(req) } @@ -305,9 +304,7 @@ fn execution_layer(req: execution::Request) -> Result<execution::Request> { let supergraph_result = req .context .extensions() - .lock() - .get::<SupergraphLayerResult>() - .cloned(); + .with_lock(|lock| lock.get::<SupergraphLayerResult>().cloned()); if let Some(supergraph_result) = supergraph_result { let SupergraphLayerResult { map, .. } = supergraph_result; @@ -321,9 +318,7 @@ async fn subgraph_layer(mut req: subgraph::Request) -> subgraph::Request { let supergraph_result = req .context .extensions() - .lock() - .get::<SupergraphLayerResult>() - .cloned(); + .with_lock(|lock| lock.get::<SupergraphLayerResult>().cloned()); if let Some(supergraph_result) = supergraph_result { let SupergraphLayerResult { multipart, map } = supergraph_result; diff --git a/apollo-router/src/plugins/progressive_override/mod.rs b/apollo-router/src/plugins/progressive_override/mod.rs index d56627f3dd..bcbf462afd 100644 --- a/apollo-router/src/plugins/progressive_override/mod.rs +++ b/apollo-router/src/plugins/progressive_override/mod.rs @@ -195,7 +195,7 @@ impl Plugin for ProgressiveOverridePlugin { let crate::graphql::Request {query, operation_name, ..} = request.supergraph_request.body(); let operation_hash = hash_operation(query, operation_name); - let maybe_parsed_doc = request.context.extensions().lock().get::<ParsedDocument>().cloned(); + let maybe_parsed_doc = request.context.extensions().with_lock(|lock| lock.get::<ParsedDocument>().cloned()); if let Some(parsed_doc) = maybe_parsed_doc { // we have to visit the operation to find out which subset // of labels are relevant unless we've already cached that diff --git a/apollo-router/src/plugins/progressive_override/tests.rs b/apollo-router/src/plugins/progressive_override/tests.rs index b2d795483f..07b6bec7d8 100644 --- a/apollo-router/src/plugins/progressive_override/tests.rs +++ b/apollo-router/src/plugins/progressive_override/tests.rs @@ -145,8 +145,7 @@ async fn assert_expected_and_absent_labels_for_supergraph_service( let context = Context::new(); context .extensions() - .lock() - .insert::<ParsedDocument>(parsed_doc); + .with_lock(|mut lock| lock.insert::<ParsedDocument>(parsed_doc)); context .insert( @@ -218,8 +217,7 @@ async fn get_json_query_plan(query: &str) -> serde_json::Value { let context: Context = Context::new(); context .extensions() - .lock() - .insert::<ParsedDocument>(parsed_doc); + .with_lock(|mut lock| lock.insert::<ParsedDocument>(parsed_doc)); let request = supergraph::Request::fake_builder() .query(query) @@ -293,8 +291,7 @@ async fn query_with_labels(query: &str, labels_from_coprocessors: Vec<&str>) { let context = Context::new(); context .extensions() - .lock() - .insert::<ParsedDocument>(parsed_doc); + .with_lock(|mut lock| lock.insert::<ParsedDocument>(parsed_doc)); context .insert( diff --git a/apollo-router/src/plugins/record_replay/record.rs b/apollo-router/src/plugins/record_replay/record.rs index 59b36dd803..131dfae33e 100644 --- a/apollo-router/src/plugins/record_replay/record.rs +++ b/apollo-router/src/plugins/record_replay/record.rs @@ -103,7 +103,9 @@ impl Plugin for Record { let context = res.context.clone(); let after_complete = once(async move { - let recording = context.extensions().lock().remove::<Recording>(); + let recording = context + .extensions() + .with_lock(|mut lock| lock.remove::<Recording>()); if let Some(mut recording) = recording { let res_headers = externalize_header_map(&headers)?; @@ -169,12 +171,14 @@ impl Plugin for Record { let recording_enabled = if req.supergraph_request.headers().contains_key(RECORD_HEADER) { - req.context.extensions().lock().insert(Recording { - supergraph_sdl: supergraph_sdl.clone().to_string(), - client_request: Default::default(), - client_response: Default::default(), - formatted_query_plan: Default::default(), - subgraph_fetches: Default::default(), + req.context.extensions().with_lock(|mut lock| { + lock.insert(Recording { + supergraph_sdl: supergraph_sdl.clone().to_string(), + client_request: Default::default(), + client_response: Default::default(), + formatted_query_plan: Default::default(), + subgraph_fetches: Default::default(), + }) }); true } else { @@ -190,26 +194,29 @@ impl Plugin for Record { let method = req.supergraph_request.method().to_string(); let uri = req.supergraph_request.uri().to_string(); - if let Some(recording) = req.context.extensions().lock().get_mut::<Recording>() - { - recording.client_request = RequestDetails { - query, - operation_name, - variables, - headers, - method, - uri, - }; - } + req.context.extensions().with_lock(|mut lock| { + if let Some(recording) = lock.get_mut::<Recording>() { + recording.client_request = RequestDetails { + query, + operation_name, + variables, + headers, + method, + uri, + }; + } + }); } req }) .map_response(|res: supergraph::Response| { let context = res.context.clone(); res.map_stream(move |chunk| { - if let Some(recording) = context.extensions().lock().get_mut::<Recording>() { - recording.client_response.chunks.push(chunk.clone()); - } + context.extensions().with_lock(|mut lock| { + if let Some(recording) = lock.get_mut::<Recording>() { + recording.client_response.chunks.push(chunk.clone()); + } + }); chunk }) @@ -221,9 +228,12 @@ impl Plugin for Record { fn execution_service(&self, service: execution::BoxService) -> execution::BoxService { ServiceBuilder::new() .map_request(|req: execution::Request| { - if let Some(recording) = req.context.extensions().lock().get_mut::<Recording>() { - recording.formatted_query_plan = req.query_plan.formatted_query_plan.clone(); - } + req.context.extensions().with_lock(|mut lock| { + if let Some(recording) = lock.get_mut::<Recording>() { + recording.formatted_query_plan = + req.query_plan.formatted_query_plan.clone(); + } + }); req }) .service(service) @@ -276,17 +286,17 @@ impl Plugin for Record { request: req, }; - if let Some(recording) = - res.context.extensions().lock().get_mut::<Recording>() - { - if recording.subgraph_fetches.is_none() { - recording.subgraph_fetches = Some(Default::default()); - } + res.context.extensions().with_lock(|mut lock| { + if let Some(recording) = lock.get_mut::<Recording>() { + if recording.subgraph_fetches.is_none() { + recording.subgraph_fetches = Some(Default::default()); + } - if let Some(fetches) = &mut recording.subgraph_fetches { - fetches.insert(operation_name, subgraph); + if let Some(fetches) = &mut recording.subgraph_fetches { + fetches.insert(operation_name, subgraph); + } } - } + }); Ok(res) } Err(err) => Err(err), diff --git a/apollo-router/src/plugins/telemetry/config_new/cost/mod.rs b/apollo-router/src/plugins/telemetry/config_new/cost/mod.rs index b6676ac4ac..02e20da6a8 100644 --- a/apollo-router/src/plugins/telemetry/config_new/cost/mod.rs +++ b/apollo-router/src/plugins/telemetry/config_new/cost/mod.rs @@ -65,7 +65,9 @@ impl Selectors for SupergraphCostAttributes { fn on_response_event(&self, _response: &Self::EventResponse, ctx: &Context) -> Vec<KeyValue> { let mut attrs = Vec::with_capacity(4); - let cost_result = ctx.extensions().lock().get::<CostContext>().cloned(); + let cost_result = ctx + .extensions() + .with_lock(|lock| lock.get::<CostContext>().cloned()); if let Some(cost_result) = cost_result { if let Some(true) = self.cost_estimated { attrs.push(KeyValue::new("cost.estimated", cost_result.estimated)); @@ -377,14 +379,13 @@ mod test { fn make_request(instruments: &CostInstruments) { let context = Context::new(); - { - let mut extensions = context.extensions().lock(); - extensions.insert(CostContext::default()); - let cost_result = extensions.get_or_default_mut::<CostContext>(); + context.extensions().with_lock(|mut lock| { + lock.insert(CostContext::default()); + let cost_result = lock.get_or_default_mut::<CostContext>(); cost_result.estimated = 100.0; cost_result.actual = 10.0; cost_result.result = "COST_TOO_EXPENSIVE" - } + }); let _ = context.insert(OPERATION_NAME, "Test".to_string()).unwrap(); instruments.on_request( &supergraph::Request::fake_builder() diff --git a/apollo-router/src/plugins/telemetry/config_new/events.rs b/apollo-router/src/plugins/telemetry/config_new/events.rs index fdd0bb3e9d..9f5689ae42 100644 --- a/apollo-router/src/plugins/telemetry/config_new/events.rs +++ b/apollo-router/src/plugins/telemetry/config_new/events.rs @@ -301,8 +301,7 @@ impl Instrumented request .context .extensions() - .lock() - .insert(SupergraphEventResponseLevel(self.response)); + .with_lock(|mut lock| lock.insert(SupergraphEventResponseLevel(self.response))); } for custom_event in &self.custom { custom_event.on_request(request); @@ -345,15 +344,13 @@ impl Instrumented request .context .extensions() - .lock() - .insert(SubgraphEventRequestLevel(self.request)); + .with_lock(|mut lock| lock.insert(SubgraphEventRequestLevel(self.request))); } if self.response != EventLevel::Off { request .context .extensions() - .lock() - .insert(SubgraphEventResponseLevel(self.response)); + .with_lock(|mut lock| lock.insert(SubgraphEventResponseLevel(self.response))); } for custom_event in &self.custom { custom_event.on_request(request); diff --git a/apollo-router/src/plugins/telemetry/config_new/graphql/mod.rs b/apollo-router/src/plugins/telemetry/config_new/graphql/mod.rs index 39be3e1d57..8772b08351 100644 --- a/apollo-router/src/plugins/telemetry/config_new/graphql/mod.rs +++ b/apollo-router/src/plugins/telemetry/config_new/graphql/mod.rs @@ -455,7 +455,9 @@ pub(crate) mod test { crate::spec::Query::parse_document(query_str, None, &schema, &Configuration::default()) .unwrap(); let context = Context::new(); - context.extensions().lock().insert(query); + context + .extensions() + .with_lock(|mut lock| lock.insert(query)); context } diff --git a/apollo-router/src/plugins/telemetry/config_new/selectors.rs b/apollo-router/src/plugins/telemetry/config_new/selectors.rs index 04bc8c222a..4a38dd0c9d 100644 --- a/apollo-router/src/plugins/telemetry/config_new/selectors.rs +++ b/apollo-router/src/plugins/telemetry/config_new/selectors.rs @@ -854,17 +854,14 @@ impl Selector for SupergraphSelector { val.maybe_to_otel_value() } .or_else(|| default.maybe_to_otel_value()), - SupergraphSelector::Cost { cost } => { - let extensions = ctx.extensions().lock(); - extensions - .get::<CostContext>() - .map(|cost_result| match cost { - CostValue::Estimated => cost_result.estimated.into(), - CostValue::Actual => cost_result.actual.into(), - CostValue::Delta => cost_result.delta().into(), - CostValue::Result => cost_result.result.into(), - }) - } + SupergraphSelector::Cost { cost } => ctx.extensions().with_lock(|lock| { + lock.get::<CostContext>().map(|cost_result| match cost { + CostValue::Estimated => cost_result.estimated.into(), + CostValue::Actual => cost_result.actual.into(), + CostValue::Delta => cost_result.delta().into(), + CostValue::Result => cost_result.result.into(), + }) + }), SupergraphSelector::OnGraphQLError { on_graphql_error } if *on_graphql_error => { if ctx.get_json_value(CONTAINS_GRAPHQL_ERROR) == Some(serde_json_bytes::Value::Bool(true)) diff --git a/apollo-router/src/plugins/telemetry/mod.rs b/apollo-router/src/plugins/telemetry/mod.rs index b34a5e093c..33819433e4 100644 --- a/apollo-router/src/plugins/telemetry/mod.rs +++ b/apollo-router/src/plugins/telemetry/mod.rs @@ -457,17 +457,14 @@ impl Plugin for Telemetry { } } - if response - .context - .extensions() - .lock() - .get::<Arc<UsageReporting>>() - .map(|u| { - u.stats_report_key == "## GraphQLValidationFailure\n" - || u.stats_report_key == "## GraphQLParseFailure\n" - }) - .unwrap_or(false) - { + if response.context.extensions().with_lock(|lock| { + lock.get::<Arc<UsageReporting>>() + .map(|u| { + u.stats_report_key == "## GraphQLValidationFailure\n" + || u.stats_report_key == "## GraphQLParseFailure\n" + }) + .unwrap_or(false) + }) { Self::update_apollo_metrics( &response.context, field_level_instrumentation_ratio, @@ -523,12 +520,7 @@ impl Plugin for Telemetry { )) .map_response(move |mut resp: SupergraphResponse| { let config = config_map_res_first.clone(); - if let Some(usage_reporting) = { - let extensions = resp.context.extensions().lock(); - let urp = extensions.get::<Arc<UsageReporting>>(); - urp.cloned() - } - { + if let Some(usage_reporting) = resp.context.extensions().with_lock(|lock| lock.get::<Arc<UsageReporting>>().cloned()) { // Record the operation signature on the router span Span::current().record( APOLLO_PRIVATE_OPERATION_SIGNATURE.as_str(), @@ -945,21 +937,17 @@ impl Telemetry { custom_events: SupergraphEvents, custom_graphql_instruments: GraphQLInstruments, ) -> Result<SupergraphResponse, BoxError> { - let mut metric_attrs = { - context - .extensions() - .lock() - .get::<MetricsAttributes>() - .cloned() - } - .map(|attrs| { - attrs - .0 - .into_iter() - .map(|(attr_name, attr_value)| KeyValue::new(attr_name, attr_value)) - .collect::<Vec<KeyValue>>() - }) - .unwrap_or_default(); + let mut metric_attrs = context + .extensions() + .with_lock(|lock| lock.get::<MetricsAttributes>().cloned()) + .map(|attrs| { + attrs + .0 + .into_iter() + .map(|(attr_name, attr_value)| KeyValue::new(attr_name, attr_value)) + .collect::<Vec<KeyValue>>() + }) + .unwrap_or_default(); let res = match result { Ok(response) => { metric_attrs.push(KeyValue::new( @@ -1083,10 +1071,11 @@ impl Telemetry { let _ = context .extensions() - .lock() - .insert(MetricsAttributes(attributes)); + .with_lock(|mut lock| lock.insert(MetricsAttributes(attributes))); if rand::thread_rng().gen_bool(field_level_instrumentation_ratio) { - context.extensions().lock().insert(EnableSubgraphFtv1); + context + .extensions() + .with_lock(|mut lock| lock.insert(EnableSubgraphFtv1)); } } @@ -1144,8 +1133,8 @@ impl Telemetry { sub_request .context .extensions() - .lock() - .insert(SubgraphMetricsAttributes(attributes)); //.unwrap(); + .with_lock(|mut lock| lock.insert(SubgraphMetricsAttributes(attributes))); + //.unwrap(); } #[allow(clippy::too_many_arguments)] @@ -1156,21 +1145,17 @@ impl Telemetry { now: Instant, result: &Result<Response, BoxError>, ) { - let mut metric_attrs = { - context - .extensions() - .lock() - .get::<SubgraphMetricsAttributes>() - .cloned() - } - .map(|attrs| { - attrs - .0 - .into_iter() - .map(|(attr_name, attr_value)| KeyValue::new(attr_name, attr_value)) - .collect::<Vec<KeyValue>>() - }) - .unwrap_or_default(); + let mut metric_attrs = context + .extensions() + .with_lock(|lock| lock.get::<SubgraphMetricsAttributes>().cloned()) + .map(|attrs| { + attrs + .0 + .into_iter() + .map(|(attr_name, attr_value)| KeyValue::new(attr_name, attr_value)) + .collect::<Vec<KeyValue>>() + }) + .unwrap_or_default(); metric_attrs.push(subgraph_attribute); // Fill attributes from context metric_attrs.extend( @@ -1370,11 +1355,10 @@ impl Telemetry { operation_kind: OperationKind, operation_subtype: Option<OperationSubType>, ) { - let metrics = if let Some(usage_reporting) = { - let lock = context.extensions().lock(); - let urp = lock.get::<Arc<UsageReporting>>(); - urp.cloned() - } { + let metrics = if let Some(usage_reporting) = context + .extensions() + .with_lock(|lock| lock.get::<Arc<UsageReporting>>().cloned()) + { let licensed_operation_count = licensed_operation_count(&usage_reporting.stats_report_key); let persisted_query_hit = context @@ -1833,8 +1817,7 @@ fn request_ftv1(mut req: SubgraphRequest) -> SubgraphRequest { if req .context .extensions() - .lock() - .contains_key::<EnableSubgraphFtv1>() + .with_lock(|lock| lock.contains_key::<EnableSubgraphFtv1>()) && Span::current().context().span().span_context().is_sampled() { req.subgraph_request @@ -1849,8 +1832,7 @@ fn store_ftv1(subgraph_name: &ByteString, resp: SubgraphResponse) -> SubgraphRes if resp .context .extensions() - .lock() - .contains_key::<EnableSubgraphFtv1>() + .with_lock(|lock| lock.contains_key::<EnableSubgraphFtv1>()) { if let Some(serde_json_bytes::Value::String(ftv1)) = resp.response.body().extensions.get("ftv1") diff --git a/apollo-router/src/plugins/traffic_shaping/deduplication.rs b/apollo-router/src/plugins/traffic_shaping/deduplication.rs index bae3f620bc..30c6293eb5 100644 --- a/apollo-router/src/plugins/traffic_shaping/deduplication.rs +++ b/apollo-router/src/plugins/traffic_shaping/deduplication.rs @@ -81,8 +81,7 @@ where if request .context .extensions() - .lock() - .contains_key::<BatchQuery>() + .with_lock(|lock| lock.contains_key::<BatchQuery>()) { return service.ready_oneshot().await?.call(request).await; } diff --git a/apollo-router/src/query_planner/bridge_query_planner.rs b/apollo-router/src/query_planner/bridge_query_planner.rs index 873dd5e399..197f320012 100644 --- a/apollo-router/src/query_planner/bridge_query_planner.rs +++ b/apollo-router/src/query_planner/bridge_query_planner.rs @@ -735,13 +735,13 @@ impl Service<QueryPlannerRequest> for BridgeQueryPlanner { let metadata = context .extensions() - .lock() - .get::<CacheKeyMetadata>() - .cloned() - .unwrap_or_default(); + .with_lock(|lock| lock.get::<CacheKeyMetadata>().cloned().unwrap_or_default()); let this = self.clone(); let fut = async move { - let mut doc = match context.extensions().lock().get::<ParsedDocument>().cloned() { + let mut doc = match context + .extensions() + .with_lock(|lock| lock.get::<ParsedDocument>().cloned()) + { None => return Err(QueryPlannerError::SpecError(SpecError::UnknownFileId)), Some(d) => d, }; @@ -772,8 +772,7 @@ impl Service<QueryPlannerRequest> for BridgeQueryPlanner { }); context .extensions() - .lock() - .insert::<ParsedDocument>(doc.clone()); + .with_lock(|mut lock| lock.insert::<ParsedDocument>(doc.clone())); } } @@ -805,16 +804,17 @@ impl Service<QueryPlannerRequest> for BridgeQueryPlanner { Err(e) => { match &e { QueryPlannerError::PlanningErrors(pe) => { - context - .extensions() - .lock() - .insert(Arc::new(pe.usage_reporting.clone())); + context.extensions().with_lock(|mut lock| { + lock.insert(Arc::new(pe.usage_reporting.clone())) + }); } QueryPlannerError::SpecError(e) => { - context.extensions().lock().insert(Arc::new(UsageReporting { - stats_report_key: e.get_error_key().to_string(), - referenced_fields_by_type: HashMap::new(), - })); + context.extensions().with_lock(|mut lock| { + lock.insert(Arc::new(UsageReporting { + stats_report_key: e.get_error_key().to_string(), + referenced_fields_by_type: HashMap::new(), + })) + }); } _ => (), } diff --git a/apollo-router/src/query_planner/caching_query_planner.rs b/apollo-router/src/query_planner/caching_query_planner.rs index f6b04aba23..1540f83b4b 100644 --- a/apollo-router/src/query_planner/caching_query_planner.rs +++ b/apollo-router/src/query_planner/caching_query_planner.rs @@ -305,9 +305,10 @@ where query = modified_query.to_string(); } - context.extensions().lock().insert::<ParsedDocument>(doc); - - context.extensions().lock().insert(caching_key.metadata); + context.extensions().with_lock(|mut lock| { + lock.insert::<ParsedDocument>(doc); + lock.insert(caching_key.metadata) + }); let request = QueryPlannerRequest { query, @@ -379,11 +380,10 @@ where Box::pin(async move { let context = request.context.clone(); qp.plan(request).await.map(|response| { - if let Some(usage_reporting) = { - let lock = context.extensions().lock(); - let urp = lock.get::<Arc<UsageReporting>>(); - urp.cloned() - } { + if let Some(usage_reporting) = context + .extensions() + .with_lock(|lock| lock.get::<Arc<UsageReporting>>().cloned()) + { let _ = response.context.insert( APOLLO_OPERATION_ID, stats_report_key_hash(usage_reporting.stats_report_key.as_str()), @@ -426,7 +426,11 @@ where .unwrap_or_default(), }; - let doc = match request.context.extensions().lock().get::<ParsedDocument>() { + let doc = match request + .context + .extensions() + .with_lock(|lock| lock.get::<ParsedDocument>().cloned()) + { None => { return Err(CacheResolverError::RetrievalError(Arc::new( // TODO: dedicated error variant? @@ -438,11 +442,11 @@ where Some(d) => d.clone(), }; - let metadata = { - let lock = request.context.extensions().lock(); - let ckm = lock.get::<CacheKeyMetadata>().cloned(); - ckm.unwrap_or_default() - }; + let metadata = request + .context + .extensions() + .with_lock(|lock| lock.get::<CacheKeyMetadata>().cloned()) + .unwrap_or_default(); let caching_key = CachingQueryKey { query: request.query.clone(), @@ -502,10 +506,9 @@ where // This will be overridden when running in ApolloMetricsGenerationMode::New mode if let Some(QueryPlannerContent::Plan { plan, .. }) = &content { - context - .extensions() - .lock() - .insert::<Arc<UsageReporting>>(plan.usage_reporting.clone()); + context.extensions().with_lock(|mut lock| { + lock.insert::<Arc<UsageReporting>>(plan.usage_reporting.clone()) + }); } Ok(QueryPlannerResponse { content, @@ -540,10 +543,9 @@ where match res { Ok(content) => { if let QueryPlannerContent::Plan { plan, .. } = &content { - context - .extensions() - .lock() - .insert::<Arc<UsageReporting>>(plan.usage_reporting.clone()); + context.extensions().with_lock(|mut lock| { + lock.insert::<Arc<UsageReporting>>(plan.usage_reporting.clone()) + }); } Ok(QueryPlannerResponse::builder() @@ -554,23 +556,19 @@ where Err(error) => { match error.deref() { QueryPlannerError::PlanningErrors(pe) => { - request - .context - .extensions() - .lock() - .insert::<Arc<UsageReporting>>(Arc::new( + request.context.extensions().with_lock(|mut lock| { + lock.insert::<Arc<UsageReporting>>(Arc::new( pe.usage_reporting.clone(), - )); + )) + }); } QueryPlannerError::SpecError(e) => { - request - .context - .extensions() - .lock() - .insert::<Arc<UsageReporting>>(Arc::new(UsageReporting { + request.context.extensions().with_lock(|mut lock| { + lock.insert::<Arc<UsageReporting>>(Arc::new(UsageReporting { stats_report_key: e.get_error_key().to_string(), referenced_fields_by_type: HashMap::new(), - })); + })) + }); } _ => {} } @@ -745,7 +743,9 @@ mod tests { .unwrap(); let context = Context::new(); - context.extensions().lock().insert::<ParsedDocument>(doc1); + context + .extensions() + .with_lock(|mut lock| lock.insert::<ParsedDocument>(doc1)); for _ in 0..5 { assert!(planner @@ -766,7 +766,9 @@ mod tests { .unwrap(); let context = Context::new(); - context.extensions().lock().insert::<ParsedDocument>(doc2); + context + .extensions() + .with_lock(|mut lock| lock.insert::<ParsedDocument>(doc2)); assert!(planner .call(query_planner::CachingRequest::new( @@ -836,7 +838,9 @@ mod tests { .unwrap(); let context = Context::new(); - context.extensions().lock().insert::<ParsedDocument>(doc); + context + .extensions() + .with_lock(|mut lock| lock.insert::<ParsedDocument>(doc)); for _ in 0..5 { assert!(planner @@ -849,8 +853,7 @@ mod tests { .unwrap() .context .extensions() - .lock() - .contains_key::<Arc<UsageReporting>>()); + .with_lock(|lock| lock.contains_key::<Arc<UsageReporting>>())); } } diff --git a/apollo-router/src/query_planner/execution.rs b/apollo-router/src/query_planner/execution.rs index dc1e27123e..0b658b7bb3 100644 --- a/apollo-router/src/query_planner/execution.rs +++ b/apollo-router/src/query_planner/execution.rs @@ -230,9 +230,7 @@ impl PlanNode { if parameters .context .extensions() - .lock() - .get::<CanceledRequest>() - .is_some() + .with_lock(|lock| lock.get::<CanceledRequest>().is_some()) { value = Value::Object(Object::default()); errors = Vec::new(); diff --git a/apollo-router/src/services/http/service.rs b/apollo-router/src/services/http/service.rs index ff8b9c984d..fa01b84cce 100644 --- a/apollo-router/src/services/http/service.rs +++ b/apollo-router/src/services/http/service.rs @@ -294,9 +294,7 @@ impl tower::Service<HttpRequest> for HttpClientService { let signing_params = context .extensions() - .lock() - .get::<Arc<SigningParamsConfig>>() - .cloned(); + .with_lock(|lock| lock.get::<Arc<SigningParamsConfig>>().cloned()); Box::pin(async move { let http_request = if let Some(signing_params) = signing_params { diff --git a/apollo-router/src/services/layers/allow_only_http_post_mutations.rs b/apollo-router/src/services/layers/allow_only_http_post_mutations.rs index c44f20ad81..2b56660856 100644 --- a/apollo-router/src/services/layers/allow_only_http_post_mutations.rs +++ b/apollo-router/src/services/layers/allow_only_http_post_mutations.rs @@ -51,9 +51,7 @@ where let doc = match req .context .extensions() - .lock() - .get::<ParsedDocument>() - .cloned() + .with_lock(|lock| lock.get::<ParsedDocument>().cloned()) { None => { let errors = vec![Error::builder() @@ -286,14 +284,13 @@ mod forbid_http_get_mutations_tests { let (_schema, executable) = ast.to_mixed_validate().unwrap(); let context = Context::new(); - context - .extensions() - .lock() - .insert::<ParsedDocument>(Arc::new(ParsedDocumentInner { + context.extensions().with_lock(|mut lock| { + lock.insert::<ParsedDocument>(Arc::new(ParsedDocumentInner { ast, executable: Arc::new(executable), hash: Default::default(), - })); + })) + }); SupergraphRequest::fake_builder() .method(method) diff --git a/apollo-router/src/services/layers/apq.rs b/apollo-router/src/services/layers/apq.rs index 57d853b162..6912c5e28c 100644 --- a/apollo-router/src/services/layers/apq.rs +++ b/apollo-router/src/services/layers/apq.rs @@ -549,9 +549,11 @@ mod apq_tests { fn new_context() -> Context { let context = Context::new(); - context.extensions().lock().insert(ClientRequestAccepts { - json: true, - ..Default::default() + context.extensions().with_lock(|mut lock| { + lock.insert(ClientRequestAccepts { + json: true, + ..Default::default() + }) }); context diff --git a/apollo-router/src/services/layers/content_negotiation.rs b/apollo-router/src/services/layers/content_negotiation.rs index f07cb38972..e116b91070 100644 --- a/apollo-router/src/services/layers/content_negotiation.rs +++ b/apollo-router/src/services/layers/content_negotiation.rs @@ -93,7 +93,9 @@ where || accepts.multipart_subscription || accepts.json { - req.context.extensions().lock().insert(accepts); + req.context + .extensions() + .with_lock(|mut lock| lock.insert(accepts)); Ok(ControlFlow::Continue(req)) } else { @@ -143,11 +145,11 @@ where json: accepts_json, multipart_defer: accepts_multipart_defer, multipart_subscription: accepts_multipart_subscription, - } = { - let lock = context.extensions().lock(); - let cra = lock.get::<ClientRequestAccepts>(); - cra.cloned().unwrap_or_default() - }; + } = context.extensions().with_lock(|lock| { + lock.get::<ClientRequestAccepts>() + .cloned() + .unwrap_or_default() + }); if !res.has_next.unwrap_or_default() && (accepts_json || accepts_wildcard) { parts diff --git a/apollo-router/src/services/layers/persisted_queries/mod.rs b/apollo-router/src/services/layers/persisted_queries/mod.rs index 16f8d7b04c..e4f8e1e68b 100644 --- a/apollo-router/src/services/layers/persisted_queries/mod.rs +++ b/apollo-router/src/services/layers/persisted_queries/mod.rs @@ -121,8 +121,7 @@ impl PersistedQueryLayer { request .context .extensions() - .lock() - .insert(UsedQueryIdFromManifest); + .with_lock(|mut lock| lock.insert(UsedQueryIdFromManifest)); tracing::info!(monotonic_counter.apollo.router.operations.persisted_queries = 1u64); Ok(request) } else if manifest_poller.augmenting_apq_with_pre_registration_and_no_safelisting() { @@ -163,18 +162,21 @@ impl PersistedQueryLayer { }; let doc = { - let context_guard = request.context.extensions().lock(); - - if context_guard.get::<UsedQueryIdFromManifest>().is_some() { - // We got this operation from the manifest, so there's no - // need to check the safelist. - drop(context_guard); + if request + .context + .extensions() + .with_lock(|lock| lock.get::<UsedQueryIdFromManifest>().is_some()) + { return Ok(request); } - match context_guard.get::<ParsedDocument>() { + let doc_opt = request + .context + .extensions() + .with_lock(|lock| lock.get::<ParsedDocument>().cloned()); + + match doc_opt { None => { - drop(context_guard); // For some reason, QueryAnalysisLayer didn't give us a document? return Err(supergraph_err( graphql_err( @@ -186,7 +188,7 @@ impl PersistedQueryLayer { StatusCode::INTERNAL_SERVER_ERROR, )); } - Some(d) => d.clone(), + Some(d) => d, } }; diff --git a/apollo-router/src/services/layers/query_analysis.rs b/apollo-router/src/services/layers/query_analysis.rs index 70dc9c50b0..d3ea639853 100644 --- a/apollo-router/src/services/layers/query_analysis.rs +++ b/apollo-router/src/services/layers/query_analysis.rs @@ -189,22 +189,19 @@ impl QueryAnalysisLayer { request .context .extensions() - .lock() - .insert::<ParsedDocument>(doc); + .with_lock(|mut lock| lock.insert::<ParsedDocument>(doc)); Ok(SupergraphRequest { supergraph_request: request.supergraph_request, context: request.context, }) } Err(errors) => { - request - .context - .extensions() - .lock() - .insert(Arc::new(UsageReporting { + request.context.extensions().with_lock(|mut lock| { + lock.insert(Arc::new(UsageReporting { stats_report_key: errors.get_error_key().to_string(), referenced_fields_by_type: HashMap::new(), - })); + })) + }); Err(SupergraphResponse::builder() .errors(errors.into_graphql_errors().unwrap_or_default()) .status_code(StatusCode::BAD_REQUEST) diff --git a/apollo-router/src/services/router/service.rs b/apollo-router/src/services/router/service.rs index e08bbfc5bc..786541af9d 100644 --- a/apollo-router/src/services/router/service.rs +++ b/apollo-router/src/services/router/service.rs @@ -261,9 +261,7 @@ impl RouterService { multipart_subscription: accepts_multipart_subscription, } = context .extensions() - .lock() - .get() - .cloned() + .with_lock(|lock| lock.get().cloned()) .unwrap_or_default(); let (mut parts, mut body) = response.into_parts(); @@ -271,9 +269,7 @@ impl RouterService { if context .extensions() - .lock() - .get::<CanceledRequest>() - .is_some() + .with_lock(|lock| lock.get::<CanceledRequest>().is_some()) { parts.status = StatusCode::from_u16(499) .expect("499 is not a standard status code but common enough"); @@ -700,7 +696,9 @@ impl RouterService { // If subgraph batching configuration exists and is enabled for any of our subgraphs, we create our shared batch details let shared_batch_details = (is_batch) .then(|| { - context.extensions().lock().insert(self.batching.clone()); + context + .extensions() + .with_lock(|mut lock| lock.insert(self.batching.clone())); self.batching.subgraph.as_ref() }) @@ -744,31 +742,32 @@ impl RouterService { new_context.extend(&context); let client_request_accepts_opt = context .extensions() - .lock() - .get::<ClientRequestAccepts>() - .cloned(); - // Sub-scope so that new_context_guard is dropped before pushing into the new - // SupergraphRequest - { - let mut new_context_guard = new_context.extensions().lock(); + .with_lock(|lock| lock.get::<ClientRequestAccepts>().cloned()); + // We are only going to insert a BatchQuery if Subgraph processing is enabled + let b_for_index_opt = if let Some(shared_batch_details) = &shared_batch_details { + Some( + Batch::query_for_index(shared_batch_details.clone(), index + 1).map_err( + |err| TranslateError { + status: StatusCode::INTERNAL_SERVER_ERROR, + error: "failed to create batch", + extension_code: "BATCHING_ERROR", + extension_details: format!("failed to create batch entry: {err}"), + }, + )?, + ) + } else { + None + }; + new_context.extensions().with_lock(|mut lock| { if let Some(client_request_accepts) = client_request_accepts_opt { - new_context_guard.insert(client_request_accepts); + lock.insert(client_request_accepts); } - new_context_guard.insert(self.batching.clone()); + lock.insert(self.batching.clone()); // We are only going to insert a BatchQuery if Subgraph processing is enabled - if let Some(shared_batch_details) = &shared_batch_details { - new_context_guard.insert( - Batch::query_for_index(shared_batch_details.clone(), index + 1).map_err( - |err| TranslateError { - status: StatusCode::INTERNAL_SERVER_ERROR, - error: "failed to create batch", - extension_code: "BATCHING_ERROR", - extension_details: format!("failed to create batch entry: {err}"), - }, - )?, - ); + if let Some(b_for_index) = b_for_index_opt { + lock.insert(b_for_index); } - } + }); results.push(SupergraphRequest { supergraph_request: new, context: new_context, @@ -776,14 +775,16 @@ impl RouterService { } if let Some(shared_batch_details) = shared_batch_details { - context.extensions().lock().insert( + let b_for_index = Batch::query_for_index(shared_batch_details, 0).map_err(|err| TranslateError { status: StatusCode::INTERNAL_SERVER_ERROR, error: "failed to create batch", extension_code: "BATCHING_ERROR", extension_details: format!("failed to create batch entry: {err}"), - })?, - ); + })?; + context + .extensions() + .with_lock(|mut lock| lock.insert(b_for_index)); } results.insert( diff --git a/apollo-router/src/services/subgraph_service.rs b/apollo-router/src/services/subgraph_service.rs index 3e484603fe..d5414f6418 100644 --- a/apollo-router/src/services/subgraph_service.rs +++ b/apollo-router/src/services/subgraph_service.rs @@ -528,9 +528,7 @@ async fn call_websocket( let signing_params = context .extensions() - .lock() - .get::<Arc<SigningParamsConfig>>() - .cloned(); + .with_lock(|lock| lock.get::<Arc<SigningParamsConfig>>().cloned()); let request = if let Some(signing_params) = signing_params { signing_params @@ -542,9 +540,7 @@ async fn call_websocket( let subgraph_request_event = context .extensions() - .lock() - .get::<SubgraphEventRequestLevel>() - .cloned(); + .with_lock(|lock| lock.get::<SubgraphEventRequestLevel>().cloned()); if let Some(level) = subgraph_request_event { let mut attrs = HashMap::with_capacity(5); attrs.insert( @@ -839,9 +835,7 @@ pub(crate) async fn process_batch( let subgraph_response_event = batch_context .extensions() - .lock() - .get::<SubgraphEventResponseLevel>() - .cloned(); + .with_lock(|lock| lock.get::<SubgraphEventResponseLevel>().cloned()); if let Some(level) = subgraph_response_event { let mut attrs = HashMap::with_capacity(5); attrs.insert( @@ -1104,17 +1098,12 @@ async fn call_http( // If we are processing a batch, then we'd like to park tasks here, but we can't park them whilst // we have the context extensions lock held. That would be very bad... // We grab the (potential) BatchQuery and then operate on it later - let opt_batch_query = { - let extensions_guard = context.extensions().lock(); - - // We need to make sure to remove the BatchQuery from the context as it holds a sender to - // the owning batch - extensions_guard - .get::<Batching>() + let opt_batch_query = context.extensions().with_lock(|lock| { + lock.get::<Batching>() .and_then(|batching_config| batching_config.batch_include(service_name).then_some(())) - .and_then(|_| extensions_guard.get::<BatchQuery>().cloned()) + .and_then(|_| lock.get::<BatchQuery>().cloned()) .and_then(|bq| (!bq.finished()).then_some(bq)) - }; + }); // If we have a batch query, then it's time for batching if let Some(query) = opt_batch_query { @@ -1201,9 +1190,7 @@ pub(crate) async fn call_single_http( let subgraph_request_event = context .extensions() - .lock() - .get::<SubgraphEventRequestLevel>() - .cloned(); + .with_lock(|lock| lock.get::<SubgraphEventRequestLevel>().cloned()); if let Some(level) = subgraph_request_event { let mut attrs = HashMap::with_capacity(5); attrs.insert( @@ -1240,9 +1227,7 @@ pub(crate) async fn call_single_http( let subgraph_response_event = context .extensions() - .lock() - .get::<SubgraphEventResponseLevel>() - .cloned(); + .with_lock(|lock| lock.get::<SubgraphEventResponseLevel>().cloned()); if let Some(level) = subgraph_response_event { let mut attrs = HashMap::with_capacity(5); attrs.insert( diff --git a/apollo-router/src/services/supergraph/service.rs b/apollo-router/src/services/supergraph/service.rs index 9a270c2f1c..c39b50d2d6 100644 --- a/apollo-router/src/services/supergraph/service.rs +++ b/apollo-router/src/services/supergraph/service.rs @@ -228,11 +228,10 @@ async fn service_call( let is_deferred = plan.is_deferred(operation_name.as_deref(), &variables); let is_subscription = plan.is_subscription(operation_name.as_deref()); - if let Some(batching) = { - let lock = context.extensions().lock(); - let batching = lock.get::<Batching>(); - batching.cloned() - } { + if let Some(batching) = context + .extensions() + .with_lock(|lock| lock.get::<Batching>().cloned()) + { if batching.enabled && (is_deferred || is_subscription) { let message = if is_deferred { "BATCHING_DEFER_UNSUPPORTED" @@ -254,7 +253,9 @@ async fn service_call( return Ok(response); } // Now perform query batch analysis - let batch_query_opt = context.extensions().lock().get::<BatchQuery>().cloned(); + let batch_query_opt = context + .extensions() + .with_lock(|lock| lock.get::<BatchQuery>().cloned()); if let Some(batch_query) = batch_query_opt { let query_hashes = plan.query_hashes(batching, operation_name.as_deref(), &variables)?; @@ -272,9 +273,7 @@ async fn service_call( .. } = context .extensions() - .lock() - .get() - .cloned() + .with_lock(|lock| lock.get().cloned()) .unwrap_or_default(); let mut subscription_tx = None; if (is_deferred && !accepts_multipart_defer) @@ -342,9 +341,7 @@ async fn service_call( let supergraph_response_event = context .extensions() - .lock() - .get::<SupergraphEventResponseLevel>() - .cloned(); + .with_lock(|lock| lock.get::<SupergraphEventResponseLevel>().cloned()); match supergraph_response_event { Some(level) => { let mut attrs = HashMap::with_capacity(4); @@ -444,9 +441,10 @@ async fn subscription_task( let mut subscription_handle = subscription_handle.clone(); let operation_signature = context .extensions() - .lock() - .get::<Arc<UsageReporting>>() - .map(|usage_reporting| usage_reporting.stats_report_key.clone()) + .with_lock(|lock| { + lock.get::<Arc<UsageReporting>>() + .map(|usage_reporting| usage_reporting.stats_report_key.clone()) + }) .unwrap_or_default(); let operation_name = context @@ -649,10 +647,9 @@ async fn plan_query( // tests will pass. // During a regular request, `ParsedDocument` is already populated during query analysis. // Some tests do populate the document, so we only do it if it's not already there. - if !{ - let lock = context.extensions().lock(); + if !context.extensions().with_lock(|lock| { lock.contains_key::<crate::services::layers::query_analysis::ParsedDocument>() - } { + }) { let doc = crate::spec::Query::parse_document( &query_str, operation_name.as_deref(), @@ -660,10 +657,9 @@ async fn plan_query( &Configuration::default(), ) .map_err(crate::error::QueryPlannerError::from)?; - context - .extensions() - .lock() - .insert::<crate::services::layers::query_analysis::ParsedDocument>(doc); + context.extensions().with_lock(|mut lock| { + lock.insert::<crate::services::layers::query_analysis::ParsedDocument>(doc) + }); } let qpr = planning diff --git a/apollo-router/src/services/supergraph/tests.rs b/apollo-router/src/services/supergraph/tests.rs index 5d0f113a5e..c200068299 100644 --- a/apollo-router/src/services/supergraph/tests.rs +++ b/apollo-router/src/services/supergraph/tests.rs @@ -1695,9 +1695,11 @@ async fn reconstruct_deferred_query_under_interface() { fn subscription_context() -> Context { let context = Context::new(); - context.extensions().lock().insert(ClientRequestAccepts { - multipart_subscription: true, - ..Default::default() + context.extensions().with_lock(|mut lock| { + lock.insert(ClientRequestAccepts { + multipart_subscription: true, + ..Default::default() + }) }); context @@ -1705,9 +1707,11 @@ fn subscription_context() -> Context { fn defer_context() -> Context { let context = Context::new(); - context.extensions().lock().insert(ClientRequestAccepts { - multipart_defer: true, - ..Default::default() + context.extensions().with_lock(|mut lock| { + lock.insert(ClientRequestAccepts { + multipart_defer: true, + ..Default::default() + }) }); context From b75a02ac0dd4c0d74ecfbca943c353681c3e2d59 Mon Sep 17 00:00:00 2001 From: Gary Pennington <gary@apollographql.com> Date: Thu, 6 Jun 2024 16:46:12 +0100 Subject: [PATCH 4/4] Make it clear that we prefer with_lock() to lock() In the comments for lock() --- apollo-router/src/context/extensions/sync.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apollo-router/src/context/extensions/sync.rs b/apollo-router/src/context/extensions/sync.rs index bc35c0b5a8..a80d255abc 100644 --- a/apollo-router/src/context/extensions/sync.rs +++ b/apollo-router/src/context/extensions/sync.rs @@ -24,6 +24,8 @@ impl ExtensionsMutex { /// It is CRITICAL to avoid holding on to the mutex guard for too long, particularly across async calls. /// Doing so may cause performance degradation or even deadlocks. /// + /// DEPRECATED: prefer with_lock() + /// /// See related clippy lint for examples: <https://rust-lang.github.io/rust-clippy/master/index.html#/await_holding_lock> #[deprecated] pub fn lock(&self) -> ExtensionsGuard {