diff --git a/apollo-router-core/src/context.rs b/apollo-router-core/src/context.rs index 9afd827e41..3e9fe58b51 100644 --- a/apollo-router-core/src/context.rs +++ b/apollo-router-core/src/context.rs @@ -4,7 +4,7 @@ use futures::Future; use std::sync::Arc; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Context>> { /// Original request to the Router. pub request: T, diff --git a/apollo-router-core/src/error.rs b/apollo-router-core/src/error.rs index 227c997e50..21a1964444 100644 --- a/apollo-router-core/src/error.rs +++ b/apollo-router-core/src/error.rs @@ -12,7 +12,7 @@ use tracing::level_filters::LevelFilter; /// /// Note that these are not actually returned to the client, but are instead converted to JSON for /// [`struct@Error`]. -#[derive(Error, Display, Debug, Serialize, Deserialize)] +#[derive(Error, Display, Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] #[ignore_extra_doc_attributes] pub enum FetchError { diff --git a/apollo-router-core/src/layers/deduplication.rs b/apollo-router-core/src/layers/deduplication.rs new file mode 100644 index 0000000000..57cd140967 --- /dev/null +++ b/apollo-router-core/src/layers/deduplication.rs @@ -0,0 +1,130 @@ +use crate::{fetch::OperationKind, http_compat, Request, SubgraphRequest, SubgraphResponse}; +use futures::{future::BoxFuture, lock::Mutex}; +use std::{collections::HashMap, sync::Arc, task::Poll}; +use tokio::sync::broadcast::{self, Sender}; +use tower::{BoxError, Layer, ServiceExt}; + +pub struct QueryDeduplicationLayer; + +impl Layer for QueryDeduplicationLayer +where + S: tower::Service + Clone, +{ + type Service = QueryDeduplicationService; + + fn layer(&self, service: S) -> Self::Service { + QueryDeduplicationService::new(service) + } +} + +type WaitMap = + Arc, Sender>>>>; + +pub struct QueryDeduplicationService { + service: S, + wait_map: WaitMap, +} + +impl QueryDeduplicationService +where + S: tower::Service + Clone, +{ + fn new(service: S) -> Self { + QueryDeduplicationService { + service, + wait_map: Arc::new(Mutex::new(HashMap::new())), + } + } + + async fn dedup( + service: S, + wait_map: WaitMap, + request: SubgraphRequest, + ) -> Result { + loop { + let mut locked_wait_map = wait_map.lock().await; + match locked_wait_map.get_mut(&request.http_request) { + Some(waiter) => { + // Register interest in key + let mut receiver = waiter.subscribe(); + drop(locked_wait_map); + + match receiver.recv().await { + Ok(value) => { + return value + .map(|response| SubgraphResponse { + response: response.response, + context: request.context, + }) + .map_err(|e| e.into()) + } + // there was an issue with the broadcast channel, retry fetching + Err(_) => continue, + } + } + None => { + let (tx, _rx) = broadcast::channel(1); + locked_wait_map.insert(request.http_request.clone(), tx.clone()); + drop(locked_wait_map); + + let context = request.context.clone(); + let http_request = request.http_request.clone(); + let res = service.ready_oneshot().await?.call(request).await; + + { + let mut locked_wait_map = wait_map.lock().await; + locked_wait_map.remove(&http_request); + } + + // Let our waiters know + let broadcast_value = res + .as_ref() + .map(|response| response.clone()) + .map_err(|e| e.to_string()); + + // Our use case is very specific, so we are sure that + // we won't get any errors here. + tokio::task::spawn_blocking(move || { + tx.send(broadcast_value) + .expect("there is always at least one receiver alive, the _rx guard; qed") + }).await + .expect("can only fail if the task is aborted or if the internal code panics, neither is possible here; qed"); + + return res.map(|response| SubgraphResponse { + response: response.response, + context, + }); + } + } + } + } +} + +impl tower::Service for QueryDeduplicationService +where + S: tower::Service + + Clone + + Send + + 'static, + >::Future: Send + 'static, +{ + type Response = SubgraphResponse; + type Error = BoxError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: SubgraphRequest) -> Self::Future { + let mut service = self.service.clone(); + + if request.operation_kind == OperationKind::Query { + let wait_map = self.wait_map.clone(); + + Box::pin(async move { Self::dedup(service, wait_map, request).await }) + } else { + Box::pin(async move { service.call(request).await }) + } + } +} diff --git a/apollo-router-core/src/layers/headers.rs b/apollo-router-core/src/layers/headers.rs index 28c9617a96..04350dd286 100644 --- a/apollo-router-core/src/layers/headers.rs +++ b/apollo-router-core/src/layers/headers.rs @@ -271,6 +271,7 @@ mod test { use tower::{BoxError, Layer}; use tower::{Service, ServiceExt}; + use crate::fetch::OperationKind; use crate::headers::{ InsertConfig, InsertLayer, PropagateConfig, PropagateLayer, RemoveConfig, RemoveLayer, }; @@ -471,6 +472,7 @@ mod test { .body(Request::builder().query("query").build()) .unwrap() .into(), + operation_kind: OperationKind::Query, context: example_originating_request(), } } diff --git a/apollo-router-core/src/layers/mod.rs b/apollo-router-core/src/layers/mod.rs index 493f3bc7f3..c45fc3989b 100644 --- a/apollo-router-core/src/layers/mod.rs +++ b/apollo-router-core/src/layers/mod.rs @@ -1,3 +1,4 @@ pub mod cache; +pub mod deduplication; pub mod forbid_http_get_mutations; pub mod headers; diff --git a/apollo-router-core/src/query_planner/mod.rs b/apollo-router-core/src/query_planner/mod.rs index bfa4e0bbea..8955f7eb43 100644 --- a/apollo-router-core/src/query_planner/mod.rs +++ b/apollo-router-core/src/query_planner/mod.rs @@ -265,9 +265,9 @@ pub(crate) mod fetch { operation_kind: OperationKind, } - #[derive(Debug, PartialEq, Deserialize)] + #[derive(Copy, Clone, Debug, PartialEq, Deserialize)] #[serde(rename_all = "camelCase")] - pub(crate) enum OperationKind { + pub enum OperationKind { Query, Mutation, Subscription, @@ -349,6 +349,7 @@ pub(crate) mod fetch { ) -> Result { let FetchNode { operation, + operation_kind, service_name, .. } = self; @@ -375,6 +376,7 @@ pub(crate) mod fetch { .unwrap() .into(), context: context.clone(), + operation_kind: *operation_kind, }; let service = service_registry diff --git a/apollo-router-core/src/query_planner/snapshots/apollo_router_core__query_planner__tests__query_plan_from_json.snap b/apollo-router-core/src/query_planner/snapshots/apollo_router_core__query_planner__tests__query_plan_from_json.snap index ee284151b1..45f25c4487 100644 --- a/apollo-router-core/src/query_planner/snapshots/apollo_router_core__query_planner__tests__query_plan_from_json.snap +++ b/apollo-router-core/src/query_planner/snapshots/apollo_router_core__query_planner__tests__query_plan_from_json.snap @@ -1,6 +1,6 @@ --- source: apollo-router-core/src/query_planner/mod.rs -assertion_line: 473 +assertion_line: 467 expression: query_plan --- diff --git a/apollo-router-core/src/request.rs b/apollo-router-core/src/request.rs index 717d033bfc..fc2760daec 100644 --- a/apollo-router-core/src/request.rs +++ b/apollo-router-core/src/request.rs @@ -10,7 +10,7 @@ use typed_builder::TypedBuilder; #[derive(Clone, Derivative, Serialize, Deserialize, TypedBuilder, Default)] #[serde(rename_all = "camelCase")] #[builder(field_defaults(setter(into)))] -#[derivative(Debug, PartialEq)] +#[derivative(Debug, PartialEq, Eq, Hash)] pub struct Request { /// The graphql query. pub query: String, diff --git a/apollo-router-core/src/services/http_compat.rs b/apollo-router-core/src/services/http_compat.rs index 348a557169..c237f75c82 100644 --- a/apollo-router-core/src/services/http_compat.rs +++ b/apollo-router-core/src/services/http_compat.rs @@ -1,6 +1,10 @@ //! wrapper typpes for Request and Response from the http crate to improve their usability -use std::ops::{Deref, DerefMut}; +use std::{ + cmp::PartialEq, + hash::Hash, + ops::{Deref, DerefMut}, +}; #[derive(Debug, Default)] pub struct Request { @@ -61,6 +65,52 @@ impl DerefMut for Request { } } +impl Hash for Request { + fn hash(&self, state: &mut H) { + self.inner.method().hash(state); + self.inner.version().hash(state); + self.inner.uri().hash(state); + // this assumes headers are in the same order + for (name, value) in self.inner.headers() { + name.hash(state); + value.hash(state); + } + self.inner.body().hash(state); + } +} + +impl PartialEq for Request { + fn eq(&self, other: &Self) -> bool { + let mut res = self.inner.method().eq(other.inner.method()) + && self.inner.version().eq(&other.inner.version()) + && self.inner.uri().eq(other.inner.uri()); + + if !res { + return false; + } + if self.inner.headers().len() != other.inner.headers().len() { + return false; + } + + // this assumes headers are in the same order + for ((name, value), (other_name, other_value)) in self + .inner + .headers() + .iter() + .zip(other.inner.headers().iter()) + { + res = name.eq(other_name) && value.eq(other_value); + if !res { + return false; + } + } + + self.inner.body().eq(other.inner.body()) + } +} + +impl Eq for Request {} + impl From> for Request { fn from(inner: http::Request) -> Self { Request { inner } diff --git a/apollo-router-core/src/services/mod.rs b/apollo-router-core/src/services/mod.rs index fc2262d218..a86087d0b4 100644 --- a/apollo-router-core/src/services/mod.rs +++ b/apollo-router-core/src/services/mod.rs @@ -1,5 +1,6 @@ pub use self::execution_service::*; pub use self::router_service::*; +use crate::fetch::OperationKind; use crate::layers::cache::CachingLayer; use crate::prelude::graphql::*; use moka::sync::Cache; @@ -84,9 +85,12 @@ pub struct SubgraphRequest { pub http_request: http_compat::Request, pub context: Context, + + pub operation_kind: OperationKind, } assert_impl_all!(SubgraphResponse: Send); +#[derive(Clone, Debug)] pub struct SubgraphResponse { pub response: http_compat::Response, diff --git a/apollo-router/src/reqwest_subgraph_service.rs b/apollo-router/src/reqwest_subgraph_service.rs index d905ecd5a2..78931ba6e3 100644 --- a/apollo-router/src/reqwest_subgraph_service.rs +++ b/apollo-router/src/reqwest_subgraph_service.rs @@ -2,6 +2,7 @@ use apollo_router_core::prelude::*; use futures::future::BoxFuture; use std::sync::Arc; use std::task::Poll; +use tower::BoxError; use tracing::Instrument; use typed_builder::TypedBuilder; @@ -40,7 +41,7 @@ impl ReqwestSubgraphService { impl tower::Service for ReqwestSubgraphService { type Response = graphql::SubgraphResponse; - type Error = tower::BoxError; + type Error = BoxError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { @@ -48,13 +49,13 @@ impl tower::Service for ReqwestSubgraphService { Poll::Ready(Ok(())) } - fn call( - &mut self, - graphql::SubgraphRequest { + fn call(&mut self, request: graphql::SubgraphRequest) -> Self::Future { + let graphql::SubgraphRequest { http_request, context, - }: graphql::SubgraphRequest, - ) -> Self::Future { + .. + } = request; + let http_client = self.http_client.clone(); let target_url = if http_request.uri() == "/" { self.url.clone() @@ -82,7 +83,14 @@ impl tower::Service for ReqwestSubgraphService { request.headers_mut().extend(headers.into_iter()); *request.version_mut() = version; - let response = http_client.execute(request).await?; + let response = http_client.execute(request).await.map_err(|err| { + tracing::error!(fetch_error = format!("{:?}", err).as_str()); + + graphql::FetchError::SubrequestHttpError { + service: service_name.clone(), + reason: err.to_string(), + } + })?; let body = response .bytes() .instrument(tracing::debug_span!("aggregate_response_data")) diff --git a/apollo-router/src/router_factory.rs b/apollo-router/src/router_factory.rs index 89b1e17bf3..4443b5b4ca 100644 --- a/apollo-router/src/router_factory.rs +++ b/apollo-router/src/router_factory.rs @@ -1,5 +1,6 @@ use crate::configuration::{Configuration, ConfigurationError}; use crate::reqwest_subgraph_service::ReqwestSubgraphService; +use apollo_router_core::deduplication::QueryDeduplicationLayer; use apollo_router_core::{ http_compat::{Request, Response}, PluggableRouterServiceBuilder, ResponseBody, RouterRequest, Schema, @@ -70,9 +71,9 @@ impl RouterServiceFactory for YamlRouterServiceFactory { let mut builder = PluggableRouterServiceBuilder::new(schema, buffer, dispatcher.clone()); for (name, subgraph) in &configuration.subgraphs { - let mut subgraph_service = BoxService::new(ReqwestSubgraphService::new( - name.to_string(), - subgraph.routing_url.clone(), + let dedup_layer = QueryDeduplicationLayer; + let mut subgraph_service = BoxService::new(dedup_layer.layer( + ReqwestSubgraphService::new(name.to_string(), subgraph.routing_url.clone()), )); for layer in &subgraph.layers {